# 1. 参数配置

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizer, BertModel, GPT2Tokenizer, GPT2LMHeadModel
from torch.distributions import MultivariateNormal, Categorical

import matplotlib.pyplot as plt

import time

import random


In [None]:
## wandb login
# import wandb
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# secret_value_0 = user_secrets.get_secret("wandb_key")

# wandb.login(key=secret_value_0)  # use the secret fetched above; never hard-code the API key

|Dataset|0%|25%|50%|75%|
|---|---|---|---|---|
|Banking|0+77|19+58|39+38|58+19|
|Clinc150|0+150|38+112|75+75|113+37|
|StackOverflow|0+20|5+15|10+10|15+5|

In [None]:
## Config
class config:
    """Run-wide hyperparameters and experiment identifiers."""

    # Value handed to the random-number generator before sampling labeled intents.
    seed = 200

    # --- training parameters ---
    BATCH_SIZE = 20
    EPOCHES = 100
    # Fraction of the total training steps used for learning-rate warm-up.
    num_warmup_rate = 0.05
    # n_clusters = 140

    # --- experiment identifiers ---
    model_name = 'VAE_semi'
    dataset_name = 'banking'  # alternatives noted by the author: clinc / stackoverflow
    # Number of intent classes sampled to form the labeled subset.
    n = 19

In [None]:
# Seed every random-number generator the notebook touches, not just `random`.
# `random` drives the labeled-intent sampling; PyTorch's generator drives
# DataLoader shuffling, dropout and the initialisation of newly created
# layers; NumPy is seeded too for any NumPy / scikit-learn based step.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)

# 2. 数据准备

In [None]:
## Load data

# Training split, read from a tab-separated file.
df = pd.read_csv(
    '/home/zhaojinyue/workspace/intent-cluster/archive/data/banking/train.tsv',
    delimiter='\t',
)

# Test split, read with the same settings.
df_te = pd.read_csv(
    '/home/zhaojinyue/workspace/intent-cluster/archive/data/banking/test.tsv',
    delimiter='\t',
)

# df_dev = pd.read_csv(f'/kaggle/input/intent-dataset/data/{dataset_name}/dev.tsv', sep='\t')

In [None]:
# Count of distinct values in the training split's `label` column (missing values excluded).
num_classes = df['label'].nunique(dropna=True)

In [None]:
# Give every distinct training label an integer id, numbered in the order
# the labels are returned by `unique()`.
label_mapping = {
    intent: position
    for position, intent in enumerate(pd.unique(df['label']))
}

# Attach the numeric id to both splits; the mapping comes from the training labels only.
df['label_num'] = df['label'].map(label_mapping)
df_te['label_num'] = df_te['label'].map(label_mapping)

In [None]:
# Module-level BERT tokenizer loaded from a local directory; CustomDataset.__getitem__
# reads this global name when it tokenizes a sample.
tokenizer = BertTokenizer.from_pretrained('/home/zhaojinyue/workspace/intent-cluster/bert')

class CustomDataset(Dataset):
    """Map-style dataset yielding tokenized text and a one-hot label.

    Args:
        dataframe: pandas DataFrame with a ``text`` column and an integer
            ``label_num`` column.
        num_classes: Width of the one-hot label vector.
        text_tokenizer: Optional tokenizer callable. When omitted, the
            module-level ``tokenizer`` is looked up at access time.
        max_length: Fixed sequence length every sample is padded or
            truncated to.
    """

    def __init__(self, dataframe, num_classes=20, text_tokenizer=None, max_length=128):
        self.data = dataframe
        self.num_classes = num_classes
        self.text_tokenizer = text_tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, one_hot_label)`` for row ``index``."""
        # Positional lookup: samplers emit 0..len-1, which only coincides with
        # label-based `.loc` when the frame carries a fresh RangeIndex.
        text = self.data['text'].iloc[index]
        label = int(self.data['label_num'].iloc[index])

        tokenize = self.text_tokenizer if self.text_tokenizer is not None else tokenizer
        # truncation=True keeps every sample at exactly `max_length` tokens;
        # padding alone leaves longer texts over-length, and such tensors
        # cannot be stacked into a batch.
        encoded_input = tokenize(
            text,
            return_tensors='pt',
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        one_hot_label = F.one_hot(torch.tensor(label), num_classes=self.num_classes)
        # The tokenizer returns a leading batch dimension of size 1; drop it.
        input_ids = encoded_input['input_ids'].squeeze(0)
        attention_mask = encoded_input['attention_mask'].squeeze(0)
        return input_ids, attention_mask, one_hot_label

In [None]:
# Randomly choose `config.n` intents to serve as the labeled classes.
labeled_intents = random.sample(list(label_mapping), config.n)
print(labeled_intents)

# Boolean row selector: True where the row's intent is one of the labeled classes.
labeled_idx = df['label'].isin(labeled_intents)

# Split the training frame; each part is renumbered from 0 so that
# integer positions and index labels agree.
df_labeled = df[labeled_idx].reset_index(drop=True)
df_unlabeled = df[~labeled_idx].reset_index(drop=True)


In [None]:
# One dataset each for the labeled subset, the remaining rows (used for the
# contrastive objective) and the test split, all sharing `num_classes`.
labeled_dataset, contrast_dataset, val_dataset = (
    CustomDataset(frame, num_classes)
    for frame in (df_labeled, df_unlabeled, df_te)
)

In [None]:
batch_size = config.BATCH_SIZE  # samples per training batch
shuffle = True  # shuffle the training data

# Both training loaders share the same settings; an incomplete final batch is dropped.
_train_loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle, drop_last=True)
labeled_dataloader = DataLoader(labeled_dataset, **_train_loader_kwargs)
contrast_dataloader = DataLoader(contrast_dataset, **_train_loader_kwargs)

# Evaluation loader: twice the training batch size, default ordering, keeps every sample.
val_dataloader = DataLoader(val_dataset, batch_size=2 * batch_size)

# 3.模型定义

In [None]:
# Load the BERT tokenizer and encoder from one local directory.
# (The original note also mentioned GPT-2, but no GPT-2 model is loaded here.)
_BERT_DIR = '/home/zhaojinyue/workspace/intent-cluster/bert'
bert_tokenizer = BertTokenizer.from_pretrained(_BERT_DIR)
bert_model = BertModel.from_pretrained(_BERT_DIR)

class VAE_DEC(nn.Module):
    """BERT encoder followed by a two-layer classification head.

    Args:
        bert_model: Encoder exposing ``config.hidden_size`` and returning an
            object with ``last_hidden_state`` when called with
            ``(input_ids, attention_mask=...)``.
        num_classes: Number of output classes of the head.
    """

    def __init__(self, bert_model, num_classes):
        super(VAE_DEC, self).__init__()
        self.bert_encoder = bert_model
        self.latent_dim = bert_model.config.hidden_size

        # Not used by forward(); kept so the module's state_dict keys stay
        # compatible with checkpoints saved from earlier runs.
        self.fc_mu = nn.Linear(self.latent_dim, self.latent_dim)
        self.bn_mu = nn.BatchNorm1d(self.latent_dim)  # batch-normalisation layer

        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(self.latent_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()
        # Kept for callers that want probabilities; forward() returns raw logits.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask):
        """Encode a batch and classify it.

        Returns:
            A tuple ``(hidden_state, logits)``: the encoder output at
            sequence position 0 and the unnormalised class scores.
        """
        outputs = self.bert_encoder(input_ids, attention_mask=attention_mask)
        hidden_state = outputs.last_hidden_state[:, 0, :]  # [CLS] token representation
        # Dropout regularises the hidden layer rather than the outputs.
        hidden = self.dropout(self.relu(self.fc1(hidden_state)))
        # Raw logits are returned: nn.CrossEntropyLoss applies log-softmax
        # itself, so a softmax here would be applied twice. argmax-based
        # predictions are unaffected because softmax is monotonic.
        logits = self.fc2(hidden)
        return hidden_state, logits

# `bert_model` was already loaded from the local checkpoint directory just
# before the VAE_DEC class definition, so it is reused here rather than
# reading the same weights from disk a second time.
model = VAE_DEC(bert_model, num_classes=num_classes)

# Optional warm start from a previous run (disabled):
# model_dict = model.state_dict()

# pretrained_dict = torch.load('/home/zhaojinyue/workspace/intent-cluster/vae-semi/model_epoch2.pth',map_location=torch.device('cpu'))

# pretrained_dict = {key: value for key, value in pretrained_dict.items() if key in model_dict }
# model_dict.update(pretrained_dict)
# model.load_state_dict(model_dict)

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)

# 3.损失函数和优化器

In [None]:
## wandb.init()
def class2dict(f):
    """Collect every attribute of ``f`` whose name does not start with ``__`` into a dict."""
    collected = {}
    for attr_name in dir(f):
        if attr_name.startswith('__'):
            continue
        collected[attr_name] = getattr(f, attr_name)
    return collected

# wandb.init(project='NLP-intent-VAE-supuervised', 
   # name=f'{config.dataset_name}_n{config.n}_seed{config.seed}_pretrain_ablation',
   #  config=class2dict(config),
    # group=config.model_name,
   #  job_type="train",
    # anonymous="must")

In [None]:
## SupConLoss 有监督的对比学习Loss
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, contrast_mode='all'):
        super(SupConLoss, self).__init__()
        self.contrast_mode = contrast_mode

    def forward(self, features, labels=None, mask=None, temperature = 0.07, device = None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
            temperature: divisor applied to the pairwise dot products.
            device: device for the internally created tensors; defaults to
                the device of `features` when None.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
        """

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        # Without an explicit device, follow the features so the helper
        # tensors created below never end up on a different device.
        if device is None:
            device = features.device

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # Stack the views along the batch axis: [n_views * bsz, dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # An anchor with no positive pair has a zero numerator as well;
        # dividing by 1 instead of 0 makes it contribute 0 rather than NaN.
        positives_per_anchor = mask.sum(1)
        positives_per_anchor = torch.where(
            positives_per_anchor < 1e-6,
            torch.ones_like(positives_per_anchor),
            positives_per_anchor,
        )
        mean_log_prob_pos = (mask * log_prob).sum(1) / positives_per_anchor

        # loss
        loss = - mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

In [None]:
## Loss
# The training loop zips a cycled `labeled_dataloader` against
# `contrast_dataloader` and steps the scheduler once per iteration, so the
# number of steps per epoch is the length of `contrast_dataloader`. Sizing
# the schedule from the labeled loader would give it the wrong total length.
each_epoch_steps = len(contrast_dataloader)
epoches = config.EPOCHES

# Cross-entropy for the classification head, supervised-contrastive loss for the embeddings.
criterion = nn.CrossEntropyLoss()
contrast_criterion = SupConLoss()

In [None]:
## Optimizer
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=2e-5,
    weight_decay=1e-6,
)

## Scheduler: warm-up followed by a cosine-shaped schedule
from transformers import get_cosine_schedule_with_warmup
# Total number of scheduler steps over the whole run.
num_train_steps = config.EPOCHES * each_epoch_steps
# Warm-up length as a fraction of the total step count.
num_warmup_steps = int(num_train_steps * config.num_warmup_rate)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_train_steps,
    num_cycles=0.5,
)

# 4.训练模型

In [None]:
from sklearn.cluster import KMeans
from tqdm.notebook import tqdm

# Feature dimensionality constant; presumably the encoder's hidden size — TODO confirm.
# It is not referenced anywhere else in the visible code.
feat_dim = 768

In [None]:
## evaluate
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics import adjusted_rand_score
# from sklearn.metrics import accuracy_score

from scipy.optimize import linear_sum_assignment

def calculate_acc(true_labels, pred_labels):
    """Return clustering accuracy under the best one-to-one label matching.

    A confusion matrix between true and predicted ids is built, the
    Hungarian algorithm picks the assignment of predicted ids to true ids
    that maximises the number of agreeing samples, and that count is
    divided by the number of samples.

    Args:
        true_labels: Sequence of integer ground-truth ids.
        pred_labels: Sequence of integer predicted / cluster ids, with the
            same length as ``true_labels``.

    Returns:
        The matched accuracy as a float in ``[0, 1]``.

    Raises:
        ValueError: if the inputs are empty or differ in length.
    """
    true_labels = np.array(true_labels, dtype=int)
    pred_labels = np.array(pred_labels, dtype=int)

    # A silent zip() truncation would give a wrong ratio, so reject mismatches.
    if true_labels.shape != pred_labels.shape:
        raise ValueError(
            'true_labels and pred_labels must have the same length, got '
            f'{true_labels.shape} and {pred_labels.shape}'
        )
    if true_labels.size == 0:
        raise ValueError('calculate_acc() requires at least one label')

    # Build the confusion matrix (rows: true id, columns: predicted id).
    max_label = int(max(true_labels.max(), pred_labels.max())) + 1
    confusion_matrix = np.zeros((max_label, max_label), dtype=int)
    # np.add.at accumulates repeated (row, col) pairs, unlike fancy-index `+=`.
    np.add.at(confusion_matrix, (true_labels, pred_labels), 1)

    # Hungarian algorithm; the matrix is negated because the solver minimises cost.
    row_ind, col_ind = linear_sum_assignment(-confusion_matrix)

    # Accuracy = matched samples / all samples.
    accuracy = confusion_matrix[row_ind, col_ind].sum() / len(true_labels)
    return accuracy

def evalute(true_labels, cluster_labels):
    """Score predicted cluster labels against the ground truth.

    Returns:
        A tuple ``(nmi, ari, acc)``: normalised mutual information with
        arithmetic averaging, adjusted Rand index, and the matched
        clustering accuracy from ``calculate_acc``.
    """
    return (
        normalized_mutual_info_score(true_labels, cluster_labels, average_method='arithmetic'),
        adjusted_rand_score(true_labels, cluster_labels),
        calculate_acc(true_labels, cluster_labels),
    )

# nmi, ari, acc = evalute(true_labels, cluster_labels)

# https://arxiv.org/pdf/2304.07699
# ===========================
# | NMI    | ARI    | ACC   |
# | 87.41  | 69.54  | 78.36 |

# semi-supervised

In [None]:
from itertools import cycle

In [None]:
## Training
# Semi-supervised loop: each step combines a cross-entropy loss on a labeled
# batch with a supervised-contrastive loss over the labeled + unlabeled batch.
# After every epoch the model is scored on the validation loader and
# checkpoints are written.
best_loss = 1_000

for epoch in range(epoches):
    ## train
    model.train()
    st = time.time()
    losses = []

    # zip() ends when `contrast_dataloader` is exhausted; the labeled loader
    # is cycled so it can keep pace.
    for step, (batch_labeled, batch_unlabeled) in enumerate(zip(cycle(labeled_dataloader), contrast_dataloader)):
        input_ids_labeled, attention_mask_labeled, labels = batch_labeled
        input_ids_labeled = input_ids_labeled.to(device, dtype=torch.long)
        attention_mask_labeled = attention_mask_labeled.to(device, dtype=torch.long)
        labels = labels.to(device, dtype=torch.float)

        input_ids_unlabeled, attention_mask_unlabeled, unlabels = batch_unlabeled
        # Move the *unlabeled* tensors to the device. The previous code
        # re-assigned the labeled tensors here, so the unlabeled batch was
        # never seen by the model.
        input_ids_unlabeled = input_ids_unlabeled.to(device, dtype=torch.long)
        attention_mask_unlabeled = attention_mask_unlabeled.to(device, dtype=torch.long)

        # The ground truth of the unlabeled batch is discarded: every
        # unlabeled sample receives the placeholder id -1.
        unlabels_ = unlabels.argmax(dim=1)
        unlabels_ids = torch.full_like(unlabels_, -1)
        unlabels_ids = unlabels_ids.to(device)
        labels_ids = labels.argmax(dim=1)

        # Labeled samples first, unlabeled samples second.
        input_ids = torch.cat((input_ids_labeled, input_ids_unlabeled))
        attention_mask = torch.cat((attention_mask_labeled, attention_mask_unlabeled))
        label_ids = torch.cat((labels_ids, unlabels_ids))

        # Positive-pair mask: mask[i, j] = 1 when samples i and j share an id;
        # rows of unlabeled samples are then cleared.
        batch_size = input_ids.shape[0]
        labels_expand = label_ids.expand(batch_size, batch_size)
        mask = torch.eq(labels_expand, labels_expand.T).long()
        mask[label_ids == -1, :] = 0

        # Set the diagonal to 1 so every sample is a positive of itself
        # (i.e. of its own second view inside the contrastive loss).
        logits_mask = torch.scatter(
            mask,
            0,
            torch.arange(batch_size).unsqueeze(0).to(device),
            1
        )

        ## labeled data: classification loss
        z, logits = model(input_ids_labeled, attention_mask_labeled)
        loss_ce = criterion(logits, labels)
        ## labeled + unlabeled data: contrastive learning
        # Positive / negative pairs are expressed through the mask given to the loss.
        # The same batch is encoded twice to obtain two views; presumably the
        # views differ through dropout noise inside the encoder — TODO confirm.
        z_a,logits_a = model(input_ids, attention_mask)
        z_b,logits_b = model(input_ids, attention_mask)

        norm_z_a = F.normalize(z_a)
        norm_z_b = F.normalize(z_b)

        # Shape [batch, 2 views, dim], as required by SupConLoss.
        contrastive_zs = torch.cat((norm_z_a.unsqueeze(1), norm_z_b.unsqueeze(1)), dim = 1)

        loss_contrast = contrast_criterion(contrastive_zs, mask = logits_mask, temperature = 0.07, device = device)
        loss = loss_contrast + loss_ce

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        # update the learning rate
        scheduler.step()

    v_losses = []
    v_accs = []
    ## validation
    model.eval()
    prs=[]
    gts=[]
    for input_ids, attention_mask, labels in val_dataloader:
        input_ids, attention_mask, labels = input_ids.to(device, dtype=torch.long), attention_mask.to(device, dtype=torch.long), labels.to(device, dtype=torch.float)
        with torch.no_grad():
            z,logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            v_losses.append(loss.item())
            # Batched argmax instead of one device-to-host copy per sample.
            pr = logits.argmax(dim=1).cpu().numpy()
            gt = labels.argmax(dim=1).cpu().numpy()
            prs.extend(pr)
            gts.extend(gt)
            v_accs.append((pr == gt).mean())
    ed = time.time()
    lr_cur = optimizer.param_groups[0]['lr']
    nmi, ari, acc = evalute(gts, prs)
    print(f'[Epoch {epoch+1}/{epoches}] Train Loss: {np.mean(losses):.2f}, Test Loss: {np.mean(v_losses):.2f}, Test acc: {np.mean(v_accs):.2f}, Test evaluate: {nmi}, {ari}, {acc}, lr: {lr_cur}, time: {ed-st:.0f}s')
    # wandb
    # wandb.log({
    #     f"Epoch": epoch+1,
    #     f"avg_train_loss": np.mean(losses),
    #     f"avg_test_loss": np.mean(v_losses),
    #     f"avg_test_acc": np.mean(v_accs),
    #     f"nmi": nmi,
    #     f"ari": ari,
    #     f"acc": acc,
    #     f"lr": lr_cur
    # })
    # Keep the checkpoint with the lowest mean training loss so far.
    if best_loss > np.mean(losses):
        best_loss = np.mean(losses)
        model_path = '/home/zhaojinyue/workspace/intent-cluster/NEW/model_best_tr.pth'
        torch.save(model.state_dict(), model_path)

    # Periodic snapshot every 10 epochs (epochs 0, 10, 20, ...).
    if epoch % 10==0:
        model_ep_path = f'/home/zhaojinyue/workspace/intent-cluster/AN19/S200model_epoch{epoch}.pth'
        torch.save(model.state_dict(), model_ep_path)