# 1. 参数配置

In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizer, BertModel, GPT2Tokenizer, GPT2LMHeadModel
from torch.distributions import MultivariateNormal, Categorical

import matplotlib.pyplot as plt

import time

import random

In [2]:
## wandb login
# import wandb
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# secret_value_0 = user_secrets.get_secret("wandb_key")

# wandb.login(key = secret_value_0)

|Dateset|0%|25%|50%|75%|
|---|---|---|---|---|
|Banking|0+77|19+58|39+38|58+19|
|Clinc150|0+150|38+112|75+75|113+37|
|StackOverflow|0+20|5+15|10+10|15+5|

In [3]:
## Config
class config:
    seed = 200
    ## model parameters
    max_seq_length = 128
    re_prob = 0.2
    
    ## train parameters
    BATCH_SIZE = 50
    EPOCHES = 30
    num_warmup_rate=0.05

    model_name = 'VAE_semi'
    dataset_name = 'CLINC' # clinc / stackoverflow
    n = 75

In [4]:
random.seed(config.seed)

# 2. 数据准备

In [5]:
## Load data

df = pd.read_csv('/home/zhaojinyue/workspace/intent-cluster/archive/data/clinc/train.tsv', sep='\t')

df_te = pd.read_csv('/home/zhaojinyue/workspace/intent-cluster/archive/data/clinc/test.tsv', sep='\t')

# df_dev = pd.read_csv(f'/kaggle/input/intent-dataset/data/{dataset_name}/dev.tsv', sep='\t')

In [6]:
num_classes = df['label'].nunique()#*2

In [7]:
label_mapping = {v:i for i,v in enumerate(df['label'].unique())}
df['label_num'] = df['label'].map(label_mapping)
df_te['label_num'] = df_te['label'].map(label_mapping)

In [8]:
tokenizer = BertTokenizer.from_pretrained('/home/zhaojinyue/workspace/intent-cluster/bert')

class CustomDataset(Dataset):
    def __init__(self, dataframe, num_classes=20):
        self.data = dataframe
        self.num_classes = num_classes
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        text = self.data.loc[index, 'text']
        encoded_input = tokenizer(text, return_tensors='pt', add_special_tokens=True, max_length=config.max_seq_length, padding='max_length')
        label = self.data.loc[index, 'label_num']
        one_hot_label = F.one_hot(torch.tensor(label), num_classes=self.num_classes)
        inputs_ids = encoded_input['input_ids'].squeeze(0)
        attention_mask = encoded_input['attention_mask'].squeeze(0)
        return inputs_ids, attention_mask, one_hot_label

In [9]:
labeled_intents = random.sample(list(label_mapping.keys()), config.n)
labeled_idx = df['label'].isin(labeled_intents)
df_labeled = df[labeled_idx]
df_unlabeled = df[~labeled_idx]

df_labeled.reset_index(drop=True, inplace=True)
df_unlabeled.reset_index(drop=True, inplace=True)
print(labeled_intents)

['oracle', 'matlab', 'scala', 'osx', 'wordpress', 'excel', 'drupal', 'svn', 'haskell', 'hibernate', 'bash', 'ajax', 'qt', 'cocoa', 'apache']


In [10]:
labeled_dataset = CustomDataset(df_labeled, num_classes)
contrast_dataset = CustomDataset(df_unlabeled, num_classes)

train_dataset = CustomDataset(df, num_classes)
val_dataset = CustomDataset(df_te, num_classes)

In [11]:
batch_size = config.BATCH_SIZE  # 批量大小
shuffle = True  # 打乱数据
labeled_dataloader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=shuffle)
contrast_dataloader = DataLoader(contrast_dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size*2)

# 3.模型定义

In [12]:
# 加载BERT和GPT-2
bert_model = BertModel.from_pretrained('/home/zhaojinyue/workspace/intent-cluster/bert')

class VAE_DEC(nn.Module):
    def __init__(self, bert_model, num_classes):
        super(VAE_DEC, self).__init__()
        self.bert_encoder = bert_model
        self.latent_dim = bert_model.config.hidden_size
        
        self.fc_mu = nn.Linear(self.latent_dim, self.latent_dim)
        self.fc_logvar = nn.Linear(self.latent_dim, self.latent_dim)
        self.bn_mu = nn.BatchNorm1d(self.latent_dim)  # 添加批量归一化层
        
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(self.latent_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)        
        
    def encode(self, input_ids, attention_mask=None):
        outputs = self.bert_encoder(input_ids, attention_mask=attention_mask)
        hidden_state = outputs.last_hidden_state[:, 0, :]  # 使用[CLS] token表示
        mu = self.bn_mu(self.fc_mu(hidden_state))
        return mu
    
        
    def forward(self, input_ids, attention_mask):
        mu = self.encode(input_ids, attention_mask)
#         mu = self.dropout(mu)
        hidden = self.relu(self.fc1(mu))
        logits = self.softmax(self.fc2(hidden))
        logits = self.dropout(logits)
        return mu, logits

bert_model = BertModel.from_pretrained('/home/zhaojinyue/workspace/intent-cluster/bert')

model = VAE_DEC(bert_model, num_classes = num_classes )

model_dict = model.state_dict()

# pretrained_dict = torch.load('/kaggle/input/vae-semi-pretrain-model/model_epoch70.pth',map_location=torch.device('cpu'))
pretrained_dict = torch.load('/home/zhaojinyue/workspace/intent-cluster/N15/s200model_epoch90.pth',map_location=torch.device('cpu'))

pretrained_dict = {key: value for key, value in pretrained_dict.items() if key in model_dict }
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)

Some weights of the model checkpoint at /home/zhaojinyue/workspace/intent-cluster/bert were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at /home/zhaojinyue/workspace/intent-cluster/bert were not used when initializing BertMod

VAE_DEC(
  (bert_encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

# 3.损失函数和优化器

In [13]:
## wandb.init()
def class2dict(f):
    return dict((name, getattr(f, name)) for name in dir(f) if not name.startswith('__'))

# wandb.init(project='NLP-intent-VAE-supuervised', 
#     name=f'{config.model_name}_{config.dataset_name}_n={config.n}_seed{config.seed}_trianing',
#     config=class2dict(config),
#     group=config.model_name,
#     job_type="train",
#     anonymous="must")

In [14]:
## SupConLoss 有监督的对比学习Loss
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, contrast_mode='all'):
        super(SupConLoss, self).__init__()
        self.contrast_mode = contrast_mode

    def forward(self, features, labels=None, mask=None, temperature = 0.07, device = None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

In [15]:
## Loss
l_each_epoch_steps = len(labeled_dataloader)
each_epoch_steps = len(train_dataloader)

epoches = config.EPOCHES

criterion = nn.CrossEntropyLoss()
contrast_criterion = SupConLoss()

In [16]:
## Optimizer
optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-6)
l_optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-6)

## Scheduler
from transformers import get_cosine_schedule_with_warmup
l_num_train_steps = config.EPOCHES*l_each_epoch_steps
num_train_steps = config.EPOCHES*each_epoch_steps
# num_warmup_steps = int(num_train_steps*config.num_warmup_rate)
l_num_warmup_steps = int(l_each_epoch_steps*5)
num_warmup_steps = int(each_epoch_steps*5)

scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps, num_cycles=0.5)

l_scheduler = get_cosine_schedule_with_warmup(
    l_optimizer, num_warmup_steps=l_num_warmup_steps, num_training_steps=l_num_train_steps, num_cycles=0.5)



# 4.训练模型

In [17]:
from sklearn.cluster import KMeans
from tqdm.notebook import tqdm

feat_dim = 768

In [18]:
def get_outputs(mode, model):    
    if mode == 'test':
        dataloader_ = val_dataloader
    elif mode == 'train':
        dataloader_ = train_dataloader
    model.eval()

    total_labels = torch.empty(0,dtype=torch.float).to(device) #创建空list
    total_preds = torch.empty(0,dtype=torch.long).to(device)
        
    total_features = torch.empty((0, feat_dim)).to(device)
    total_logits = torch.empty((0, num_classes)).to(device)
    
    for input_ids, attention_mask, labels in tqdm(dataloader_, desc="Iteration"):
        input_ids, attention_mask, labels = input_ids.to(device, dtype=torch.long), attention_mask.to(device, dtype=torch.long), labels.to(device, dtype=torch.float)

        with torch.set_grad_enabled(False):
            feats,logits = model(input_ids, attention_mask)
                
            total_labels = torch.cat((total_labels, labels.argmax(axis=1)))
            total_features = torch.cat((total_features, feats))
            total_logits = torch.cat((total_logits, logits))
        
    feats = total_features.cpu().numpy()
    y_true = total_labels.cpu().numpy()

    total_probs = F.softmax(total_logits.detach(), dim=1)
    total_maxprobs, total_preds = total_probs.max(dim = 1)
    y_pred = total_preds.cpu().numpy()
        
    y_logits = total_logits.cpu().numpy()
        
    outputs = {
        'y_true': y_true,
        'y_pred': y_pred,
        'logits': y_logits,
        'feats': feats
    }
    return outputs

In [19]:
def clustering(model, centroids):
    outputs = get_outputs(mode = 'train', model = model)
    feats = outputs['feats']

    if centroids is None:
        km = KMeans(n_clusters=num_classes, random_state=config.seed, init='k-means++').fit(feats)
    else:
        km = KMeans(n_clusters=num_classes, random_state=config.seed, init=centroids).fit(feats)
    km_centroids, assign_labels = km.cluster_centers_, km.labels_
         
    pseudo_labels = assign_labels.astype(np.int64)
        
    return outputs, km_centroids, pseudo_labels

In [20]:
def random_token_erase( input_ids, attention_mask, re_prob):
    tokenizer.get_special_tokens_mask(input_ids[0], already_has_special_tokens=True)
    aug_input_ids = []
    aug_input_mask = []

    for inp_i, inp_m in zip(input_ids, attention_mask):

        special_tokens_mask = tokenizer.get_special_tokens_mask(inp_i, already_has_special_tokens=True)
        sent_tokens_inds = np.where(np.array(special_tokens_mask) == 0)[0]
        inds = np.arange(len(sent_tokens_inds))
        masked_inds = np.random.choice(inds, size = int(len(inds) * re_prob), replace = False)
        sent_masked_inds = sent_tokens_inds[masked_inds]

        inp_i = np.delete(inp_i.cpu(), sent_masked_inds)
        inp_i = F.pad(inp_i, (0, config.max_seq_length - len(inp_i)), 'constant', 0)

        inp_m = np.delete(inp_m.cpu(), sent_masked_inds)
        inp_m = F.pad(inp_m, (0, config.max_seq_length - len(inp_m)), 'constant', 0)

        aug_input_ids.append(inp_i)
        aug_input_mask.append(inp_m)

    aug_input_ids = torch.stack(aug_input_ids, dim=0)
    aug_input_mask = torch.stack(aug_input_mask, dim=0)
    
    return aug_input_ids, aug_input_mask

In [21]:
## evalute
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics import adjusted_rand_score
# from sklearn.metrics import accuracy_score

from scipy.optimize import linear_sum_assignment

def calculate_acc(true_labels, pred_labels):
    true_labels = np.array(true_labels, dtype=int)
    pred_labels = np.array(pred_labels, dtype=int)
    # 构建混淆矩阵
    max_label = max(max(true_labels), max(pred_labels)) + 1
    confusion_matrix = np.zeros((max_label, max_label), dtype=int)
    
    for t, p in zip(true_labels, pred_labels):
        confusion_matrix[t, p] += 1
    
    # 使用匈牙利算法找到最佳匹配
    row_ind, col_ind = linear_sum_assignment(-confusion_matrix)
    
    # 计算准确率
    accuracy = confusion_matrix[row_ind, col_ind].sum() / len(true_labels)
    return accuracy

def evalute(true_labels, cluster_labels):
    # 计算归一化互信息
    nmi = normalized_mutual_info_score(true_labels, cluster_labels, average_method='arithmetic')
    # 计算调整兰德系数
    ari = adjusted_rand_score(true_labels, cluster_labels)
    # 计算聚类准确率
    acc = calculate_acc(true_labels, cluster_labels)

    return nmi, ari, acc

# https://arxiv.org/pdf/2304.07699
# ===========================
# | NMI    | ARI    | ACC   |
# | 87.41  | 69.54  | 78.36 |

In [22]:
def test_score(model, centroids):
        
    outputs = get_outputs(mode = 'test', model = model)
    feats = outputs['feats']
    y_true = outputs['y_true']

    km = KMeans(n_clusters = num_classes, random_state=config.seed, init = centroids).fit(feats) 
    y_pred = km.labels_

    nmi, ari, acc = evalute(y_true, y_pred)
    
    return nmi, ari, acc

In [23]:
## Traing
last_preds = None

centroids = None
best_loss = 1_000

for epoch in range(epoches):
    
    ## train 
    model.train()
    st = time.time()
    losses = []
    for input_ids, attention_mask, labels in labeled_dataloader:
        input_ids, attention_mask, labels = input_ids.to(device, dtype=torch.long), attention_mask.to(device, dtype=torch.long), labels.to(device, dtype=torch.float)

        feats_a,logits_a = model(input_ids, attention_mask)
        feats_b,logits_b = model(input_ids, attention_mask)

        norm_feats_a = F.normalize(feats_a)
        norm_feats_b = F.normalize(feats_b)
        
        constrastive_feats = torch.cat((norm_feats_a.unsqueeze(1), norm_feats_b.unsqueeze(1)), dim = 1)
        
        ## 计算对比学习Loss，使用的simCLR 的loss https://arxiv.org/pdf/2002.05709.pdf
        loss_contrast = contrast_criterion(constrastive_feats, labels = labels.argmax(axis=1), temperature = 0.07, device = device)
        
        loss = loss_contrast
        
        losses.append(loss.item())
        
        loss.backward()
        l_optimizer.step()
        l_optimizer.zero_grad()        
        l_scheduler.step()
    
    ## 更新质心和伪标签
    outputs, km_centroids, pseudo_labels = clustering(model, centroids)
    
    centroids = km_centroids
        
    ## 质心引导，对比学习训练（此处输入label为伪标签）
    losses2 = []
    model.train()
    for i, (input_ids, attention_mask, _) in enumerate(train_dataloader):
        labels_ = torch.tensor(pseudo_labels[batch_size*i:batch_size*(i+1)])
        labels_ = F.one_hot(labels_, num_classes=num_classes)
        input_ids, attention_mask, labels_ = input_ids.to(device, dtype=torch.long), attention_mask.to(device, dtype=torch.long), labels_.to(device, dtype=torch.float)
        
        # random eraze
        aug_input_ids_a, aug_input_mask_a = random_token_erase( input_ids, attention_mask, config.re_prob)
        aug_input_ids_b, aug_input_mask_b = random_token_erase( input_ids, attention_mask, config.re_prob)

        aug_input_ids_a, aug_input_mask_a = aug_input_ids_a.to(device, dtype=torch.long), aug_input_mask_a.to(device, dtype=torch.long)
        aug_input_ids_b, aug_input_mask_b = aug_input_ids_b.to(device, dtype=torch.long), aug_input_mask_b.to(device, dtype=torch.long)

        feats_a,logits_a = model(aug_input_ids_a, aug_input_mask_a)
        feats_b,logits_b = model(aug_input_ids_b, aug_input_mask_b)
    
        norm_feats_a = F.normalize(feats_a)
        norm_feats_b = F.normalize(feats_b)
        
        ## 计算对比学习
        constrastive_feats = torch.cat((norm_feats_a.unsqueeze(1), norm_feats_b.unsqueeze(1)), dim = 1)
        loss_contrast = contrast_criterion(constrastive_feats, labels = labels_.argmax(axis=1), temperature = 0.07, device = device)
        
        ## 伪标签与预测差异loss
        loss_ce = 0.5 * (criterion(logits_a, labels_) + criterion(logits_b, labels_)) 
                    
        loss = loss_contrast + loss_ce

        losses2.append(loss.item())
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
    
    nmi, ari, acc = test_score(model, centroids)
    lr_cur = optimizer.param_groups[0]['lr']

    ed = time.time()
    print(f'[Epoch {epoch+1}/{epoches}] Constrast Loss: {np.mean(losses):.2f}, Train Loss: {np.mean(losses2):.2f}, Test evaluate: {nmi}, {ari}, {acc}, lr: {lr_cur}, time: {ed-st:.0f}s')

    # wandb
    # wandb.log({
    #     f"Epoch": epoch+1,
    #     f"constrast_loss": np.mean(losses),
    #     f"train_loss": np.mean(losses2),
    #     f"nmi": nmi,
    #     f"ari": ari,
    #     f"acc": acc,
    #     f"lr": lr_cur
    # })
    
    # if best_loss > np.mean(losses2):
    #     best_loss = np.mean(losses2)
    #     model_path = '/home/zhaojinyue/workspace/intent-cluster/NEW/model_best_tr.pth'
    #     torch.save(model.state_dict(), model_path)
    # if epoch % 5==0:
    #     model_ep_path = f'model_epoch{epoch}.pth'
    #     torch.save(model.state_dict(), model_ep_path)    

Iteration:   0%|          | 0/240 [00:00<?, ?it/s]



Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 1/30] Constrast Loss: 1.96, Train Loss: 13.45, Test evaluate: 0.45058711910391114, 0.18957423944873622, 0.2455, lr: 4.000000000000001e-06, time: 544s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 2/30] Constrast Loss: 2.07, Train Loss: 8.49, Test evaluate: 0.6941992271759628, 0.5932834652926395, 0.7391666666666666, lr: 8.000000000000001e-06, time: 528s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 3/30] Constrast Loss: 1.99, Train Loss: 8.71, Test evaluate: 0.6632518731333004, 0.536045481010561, 0.6696666666666666, lr: 1.2e-05, time: 527s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 4/30] Constrast Loss: 1.97, Train Loss: 8.63, Test evaluate: 0.6565615806605265, 0.47364437755661876, 0.625, lr: 1.6000000000000003e-05, time: 534s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 5/30] Constrast Loss: 1.98, Train Loss: 8.34, Test evaluate: 0.6406327062188926, 0.46497481173809857, 0.5786666666666667, lr: 2e-05, time: 532s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 6/30] Constrast Loss: 1.97, Train Loss: 8.43, Test evaluate: 0.631711536694247, 0.4655077164101432, 0.553, lr: 1.9921147013144782e-05, time: 553s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 7/30] Constrast Loss: 1.98, Train Loss: 8.26, Test evaluate: 0.6782459918976959, 0.5848293706516591, 0.687, lr: 1.9685831611286312e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 8/30] Constrast Loss: 1.98, Train Loss: 8.23, Test evaluate: 0.6425097540177176, 0.5076875540816621, 0.5763333333333334, lr: 1.9297764858882516e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 9/30] Constrast Loss: 1.98, Train Loss: 8.37, Test evaluate: 0.6218747860937803, 0.4423603305537944, 0.5916666666666667, lr: 1.8763066800438638e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 10/30] Constrast Loss: 2.03, Train Loss: 8.24, Test evaluate: 0.7190222239876838, 0.6110362147004599, 0.7666666666666667, lr: 1.8090169943749477e-05, time: 527s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 11/30] Constrast Loss: 1.97, Train Loss: 8.33, Test evaluate: 0.6646950033872817, 0.5007292746956529, 0.5673333333333334, lr: 1.7289686274214116e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 12/30] Constrast Loss: 1.98, Train Loss: 8.22, Test evaluate: 0.740863711753232, 0.6534421866061915, 0.7606666666666667, lr: 1.63742398974869e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 13/30] Constrast Loss: 1.96, Train Loss: 8.25, Test evaluate: 0.699674677216466, 0.5732826644493673, 0.6865, lr: 1.5358267949789968e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 14/30] Constrast Loss: 1.98, Train Loss: 8.32, Test evaluate: 0.6935177979736564, 0.48956307155163764, 0.6125, lr: 1.4257792915650728e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 15/30] Constrast Loss: 1.98, Train Loss: 8.36, Test evaluate: 0.697880160443601, 0.5319920100740263, 0.6173333333333333, lr: 1.3090169943749475e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 16/30] Constrast Loss: 1.99, Train Loss: 8.30, Test evaluate: 0.7130014658975752, 0.5651083433501504, 0.6058333333333333, lr: 1.187381314585725e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 17/30] Constrast Loss: 2.00, Train Loss: 8.25, Test evaluate: 0.7229985166916197, 0.5651199387529315, 0.632, lr: 1.0627905195293135e-05, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 18/30] Constrast Loss: 2.01, Train Loss: 8.22, Test evaluate: 0.7314470401236717, 0.5583546448156368, 0.6731666666666667, lr: 9.372094804706867e-06, time: 525s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 19/30] Constrast Loss: 2.02, Train Loss: 8.20, Test evaluate: 0.768721198577195, 0.6923484450681185, 0.7815, lr: 8.126186854142752e-06, time: 525s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 20/30] Constrast Loss: 2.03, Train Loss: 8.16, Test evaluate: 0.7670511661984728, 0.674573573738994, 0.7671666666666667, lr: 6.909830056250527e-06, time: 526s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 21/30] Constrast Loss: 2.05, Train Loss: 8.13, Test evaluate: 0.7560206714238727, 0.6326066375577694, 0.7266666666666667, lr: 5.742207084349274e-06, time: 525s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 22/30] Constrast Loss: 2.07, Train Loss: 8.10, Test evaluate: 0.7801643320492141, 0.6426786034448854, 0.7991666666666667, lr: 4.641732050210032e-06, time: 539s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 23/30] Constrast Loss: 2.09, Train Loss: 8.05, Test evaluate: 0.7900428191754704, 0.7037611801722766, 0.8228333333333333, lr: 3.625760102513103e-06, time: 539s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 24/30] Constrast Loss: 2.11, Train Loss: 8.01, Test evaluate: 0.7887153413671573, 0.7106324647752705, 0.8598333333333333, lr: 2.7103137257858867e-06, time: 534s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 25/30] Constrast Loss: 2.16, Train Loss: 7.99, Test evaluate: 0.7935650389922605, 0.7073390185545484, 0.8608333333333333, lr: 1.9098300562505266e-06, time: 525s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 26/30] Constrast Loss: 2.20, Train Loss: 7.96, Test evaluate: 0.7933257223319152, 0.7083545918654294, 0.861, lr: 1.2369331995613664e-06, time: 522s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 27/30] Constrast Loss: 2.24, Train Loss: 7.97, Test evaluate: 0.7927463759788442, 0.709049183029756, 0.8605, lr: 7.022351411174866e-07, time: 525s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 28/30] Constrast Loss: 2.32, Train Loss: 7.93, Test evaluate: 0.793463704910718, 0.7058440527035819, 0.8595, lr: 3.1416838871368925e-07, time: 525s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


[Epoch 29/30] Constrast Loss: 2.38, Train Loss: 7.90, Test evaluate: 0.7940110047268815, 0.7058224253813998, 0.8595, lr: 7.885298685522235e-08, time: 525s


Iteration:   0%|          | 0/240 [00:00<?, ?it/s]

  super()._check_params_vs_input(X, default_n_init=10)


Iteration:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 30/30] Constrast Loss: 2.42, Train Loss: 7.85, Test evaluate: 0.7939762311304148, 0.7058390783976195, 0.8595, lr: 0.0, time: 525s


  super()._check_params_vs_input(X, default_n_init=10)
