In [1]:
import argparse
import random
from sampler import data_sampler
from config import Config
import torch
from model.bert_encoder import Bert_Encoder
from model.dropout_layer import Dropout_Layer
from model.classifier import Softmax_Layer, Proto_Softmax_Layer
from data_loader import get_data_loader
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.cluster import KMeans
import collections
from copy import deepcopy
import os
# import wandb
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def train_simple_model(config, encoder, dropout_layer, classifier, training_data, epochs, map_relid2tempid):
    data_loader = get_data_loader(config, training_data, shuffle=True)

    encoder.train()
    dropout_layer.train()
    classifier.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam([
        {'params': encoder.parameters(), 'lr': 0.00001},
        {'params': dropout_layer.parameters(), 'lr': 0.00001},
        {'params': classifier.parameters(), 'lr': 0.001}
    ])
    for epoch_i in range(epochs):
        losses = []
        for step, batch_data in enumerate(data_loader):
            optimizer.zero_grad()
            labels, _, tokens = batch_data
            labels = labels.to(config.device)
            labels = [map_relid2tempid[x.item()] for x in labels]
            labels = torch.tensor(labels).to(config.device)

            tokens = torch.stack([x.to(config.device) for x in tokens],dim=0)
            reps = encoder(tokens)
            reps, _ = dropout_layer(reps)
            logits = classifier(reps)
            loss = criterion(logits, labels)

            losses.append(loss.item())
            loss.backward()
            optimizer.step()
        print(f"loss is {np.array(losses).mean()}")


def compute_jsd_loss(m_input):
    # m_input: the result of m times dropout after the classifier.
    # size: m*B*C
    m = m_input.shape[0]
    mean = torch.mean(m_input, dim=0)
    jsd = 0
    for i in range(m):
        loss = F.kl_div(F.log_softmax(mean, dim=-1), F.softmax(m_input[i], dim=-1), reduction='none')
        loss = loss.sum()
        jsd += loss / m
    return jsd


def contrastive_loss(hidden, labels):

    logsoftmax = nn.LogSoftmax(dim=-1)

    return -(logsoftmax(hidden) * labels).sum() / labels.sum()


def construct_hard_triplets(output, labels, relation_data):
    positive = []
    negative = []
    pdist = nn.PairwiseDistance(p=2)
    for rep, label in zip(output, labels):
        positive_relation_data = relation_data[label.item()]
        negative_relation_data = []
        for key in relation_data.keys():
            if key != label.item():
                negative_relation_data.extend(relation_data[key])
        positive_distance = torch.stack([pdist(rep.cpu(), p) for p in positive_relation_data])
        negative_distance = torch.stack([pdist(rep.cpu(), n) for n in negative_relation_data])
        positive_index = torch.argmax(positive_distance)
        negative_index = torch.argmin(negative_distance)
        positive.append(positive_relation_data[positive_index.item()])
        negative.append(negative_relation_data[negative_index.item()])


    return positive, negative


def train_first(config, encoder, dropout_layer, classifier, training_data, epochs, map_relid2tempid, new_relation_data):
    data_loader = get_data_loader(config, training_data, shuffle=True)

    encoder.train()
    dropout_layer.train()
    classifier.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam([
        {'params': encoder.parameters(), 'lr': 0.00001},
        {'params': dropout_layer.parameters(), 'lr': 0.00001},
        {'params': classifier.parameters(), 'lr': 0.001}
    ])
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
    for epoch_i in range(epochs):
        losses = []
        for step, (labels, _, tokens) in enumerate(data_loader):

            optimizer.zero_grad()

            logits_all = []
            tokens = torch.stack([x.to(config.device) for x in tokens], dim=0)
            labels = labels.to(config.device)
            origin_labels = labels[:]
            labels = [map_relid2tempid[x.item()] for x in labels]
            labels = torch.tensor(labels).to(config.device)
            reps = encoder(tokens)
            outputs,_ = dropout_layer(reps)
            positives,negatives = construct_hard_triplets(outputs, origin_labels, new_relation_data)

            for _ in range(config.f_pass):
                output, output_embedding = dropout_layer(reps)
                logits = classifier(output)
                logits_all.append(logits)

            positives = torch.cat(positives, 0).to(config.device)
            negatives = torch.cat(negatives, 0).to(config.device)
            anchors = outputs
            logits_all = torch.stack(logits_all)
            m_labels = labels.expand((config.f_pass, labels.shape[0]))  # m,B
            loss1 = criterion(logits_all.reshape(-1, logits_all.shape[-1]), m_labels.reshape(-1))
            loss2 = compute_jsd_loss(logits_all)
            tri_loss = triplet_loss(anchors, positives, negatives)
            loss = loss1 + loss2 + tri_loss

            loss.backward()
            losses.append(loss.item())
            optimizer.step()
        print(f"loss is {np.array(losses).mean()}")


def train_mem_model(config, encoder, dropout_layer, classifier, training_data, epochs, map_relid2tempid, new_relation_data,
                prev_encoder, prev_dropout_layer, prev_classifier, prev_relation_index):
    data_loader = get_data_loader(config, training_data, shuffle=True)

    encoder.train()
    dropout_layer.train()
    classifier.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam([
        {'params': encoder.parameters(), 'lr': 0.00001},
        {'params': dropout_layer.parameters(), 'lr': 0.00001},
        {'params': classifier.parameters(), 'lr': 0.001}
    ])
    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
    distill_criterion = nn.CosineEmbeddingLoss()
    T = config.kl_temp
    for epoch_i in range(epochs):
        losses = []
        for step, (labels, _, tokens) in enumerate(data_loader):

            optimizer.zero_grad()

            logits_all = []
            tokens = torch.stack([x.to(config.device) for x in tokens], dim=0)
            labels = labels.to(config.device)
            origin_labels = labels[:]
            labels = [map_relid2tempid[x.item()] for x in labels]
            labels = torch.tensor(labels).to(config.device)
            reps = encoder(tokens)
            normalized_reps_emb = F.normalize(reps.view(-1, reps.size()[1]), p=2, dim=1)
            outputs,_ = dropout_layer(reps)
            if prev_dropout_layer is not None:
                prev_outputs, _ = prev_dropout_layer(reps)
                positives,negatives = construct_hard_triplets(prev_outputs, origin_labels, new_relation_data)
            else:
                positives, negatives = construct_hard_triplets(outputs, origin_labels, new_relation_data)

            for _ in range(config.f_pass):
                output, output_embedding = dropout_layer(reps)
                logits = classifier(output)
                logits_all.append(logits)

            positives = torch.cat(positives, 0).to(config.device)
            negatives = torch.cat(negatives, 0).to(config.device)
            anchors = outputs
            logits_all = torch.stack(logits_all)
            m_labels = labels.expand((config.f_pass, labels.shape[0]))  # m,B
            loss1 = criterion(logits_all.reshape(-1, logits_all.shape[-1]), m_labels.reshape(-1))
            loss2 = compute_jsd_loss(logits_all)
            tri_loss = triplet_loss(anchors, positives, negatives)
            loss = loss1 + loss2 + tri_loss

            if prev_encoder is not None:
                prev_reps = prev_encoder(tokens).detach()
                normalized_prev_reps_emb = F.normalize(prev_reps.view(-1, prev_reps.size()[1]), p=2, dim=1)

                feature_distill_loss = distill_criterion(normalized_reps_emb, normalized_prev_reps_emb,
                                                         torch.ones(tokens.size(0)).to(
                                                             config.device))
                loss += feature_distill_loss

            if prev_dropout_layer is not None and prev_classifier is not None:
                prediction_distill_loss = None
                dropout_output_all = []
                prev_dropout_output_all = []
                for i in range(config.f_pass):
                    output, _ = dropout_layer(reps)
                    prev_output, _ = prev_dropout_layer(reps)
                    dropout_output_all.append(output)
                    prev_dropout_output_all.append(output)
                    pre_logits = prev_classifier(output).detach()

                    pre_logits = F.softmax(pre_logits.index_select(1, prev_relation_index) / T, dim=1)

                    log_logits = F.log_softmax(logits_all[i].index_select(1, prev_relation_index) / T, dim=1)
                    if i == 0:
                        prediction_distill_loss = -torch.mean(torch.sum(pre_logits * log_logits, dim=1))
                    else:
                        prediction_distill_loss += -torch.mean(torch.sum(pre_logits * log_logits, dim=1))

                prediction_distill_loss /= config.f_pass
                loss += prediction_distill_loss
                dropout_output_all = torch.stack(dropout_output_all)
                prev_dropout_output_all = torch.stack(prev_dropout_output_all)
                mean_dropout_output_all = torch.mean(dropout_output_all, dim=0)
                mean_prev_dropout_output_all = torch.mean(prev_dropout_output_all,dim=0)
                normalized_output = F.normalize(mean_dropout_output_all.view(-1, mean_dropout_output_all.size()[1]), p=2, dim=1)
                normalized_prev_output = F.normalize(mean_prev_dropout_output_all.view(-1, mean_prev_dropout_output_all.size()[1]), p=2, dim=1)
                hidden_distill_loss = distill_criterion(normalized_output, normalized_prev_output,
                                                         torch.ones(tokens.size(0)).to(
                                                             config.device))
                loss += hidden_distill_loss

            loss.backward()
            losses.append(loss.item())
            optimizer.step()
        print(f"loss is {np.array(losses).mean()}")




def batch2device(batch_tuple, device):
    ans = []
    for var in batch_tuple:
        if isinstance(var, torch.Tensor):
            ans.append(var.to(device))
        elif isinstance(var, list):
            ans.append(batch2device(var))
        elif isinstance(var, tuple):
            ans.append(tuple(batch2device(var)))
        else:
            ans.append(var)
    return ans


def evaluate_strict_model(config, encoder, dropout_layer, classifier, test_data, seen_relations, map_relid2tempid):
    data_loader = get_data_loader(config, test_data, batch_size=1)
    encoder.eval()
    dropout_layer.eval()
    classifier.eval()
    n = len(test_data)

    correct = 0
    for step, batch_data in enumerate(data_loader):
        labels, _, tokens = batch_data
        labels = labels.to(config.device)
        labels = [map_relid2tempid[x.item()] for x in labels]
        labels = torch.tensor(labels).to(config.device)

        tokens = torch.stack([x.to(config.device) for x in tokens],dim=0)
        reps = encoder(tokens)
        reps, _ = dropout_layer(reps)
        logits = classifier(reps)

        seen_relation_ids = [rel2id[relation] for relation in seen_relations]
        seen_relation_ids = [map_relid2tempid[relation] for relation in seen_relation_ids]
        seen_sim = logits[:,seen_relation_ids].cpu().data.numpy()
        max_smi = np.max(seen_sim,axis=1)

        label_smi = logits[:,labels].cpu().data.numpy()

        if label_smi >= max_smi:
            correct += 1

    return correct/n


def select_data(config, encoder, dropout_layer, relation_dataset):
    data_loader = get_data_loader(config, relation_dataset, shuffle=False, drop_last=False, batch_size=1)
    features = []
    encoder.eval()
    dropout_layer.eval()
    for step, batch_data in enumerate(data_loader):
        labels, _, tokens = batch_data
        tokens = torch.stack([x.to(config.device) for x in tokens],dim=0)
        with torch.no_grad():
            feature = dropout_layer(encoder(tokens))[1].cpu()
        features.append(feature)

    features = np.concatenate(features)
    num_clusters = min(config.num_protos, len(relation_dataset))
    distances = KMeans(n_clusters=num_clusters, random_state=0).fit_transform(features)

    memory = []
    for k in range(num_clusters):
        sel_index = np.argmin(distances[:, k])
        instance = relation_dataset[sel_index]
        memory.append(instance)
    return memory


def get_proto(config, encoder, dropout_layer, relation_dataset):
    data_loader = get_data_loader(config, relation_dataset, shuffle=False, drop_last=False, batch_size=1)
    features = []
    encoder.eval()
    dropout_layer.eval()
    for step, batch_data in enumerate(data_loader):
        labels, _, tokens = batch_data
        tokens = torch.stack([x.to(config.device) for x in tokens],dim=0)
        with torch.no_grad():
            feature = dropout_layer(encoder(tokens))[1]
        features.append(feature)
    features = torch.cat(features, dim=0)
    proto = torch.mean(features, dim=0, keepdim=True).cpu()
    standard = torch.sqrt(torch.var(features, dim=0)).cpu()
    return proto, standard


def generate_relation_data(protos, relation_standard):
    relation_data = {}
    relation_sample_nums = 10
    for id in protos.keys():
        relation_data[id] = []
        difference = np.random.normal(loc=0, scale=1, size=relation_sample_nums)
        for diff in difference:
            relation_data[id].append(protos[id] + diff * relation_standard[id])
    return relation_data


def generate_current_relation_data(config, encoder, dropout_layer, relation_dataset):
    data_loader = get_data_loader(config, relation_dataset, shuffle=False, drop_last=False, batch_size=1)
    relation_data = []
    encoder.eval()
    dropout_layer.eval()
    for step, batch_data in enumerate(data_loader):
        labels, _, tokens = batch_data
        tokens = torch.stack([x.to(config.device) for x in tokens],dim=0)
        with torch.no_grad():
            feature = dropout_layer(encoder(tokens))[1].cpu()
        relation_data.append(feature)
    return relation_data

from transformers import  BertTokenizer
def data_augmentation(config, encoder, train_data, prev_train_data):
    expanded_train_data = train_data[:]
    expanded_prev_train_data = prev_train_data[:]
    encoder.eval()
    all_data = train_data + prev_train_data
    tokenizer = BertTokenizer.from_pretrained(config.bert_path, additional_special_tokens=["[E11]", "[E12]", "[E21]", "[E22]"])
    entity_index = []
    entity_mention = []
    for sample in all_data:
        e11 = sample['tokens'].index(30522)
        e12 = sample['tokens'].index(30523)
        e21 = sample['tokens'].index(30524)
        e22 = sample['tokens'].index(30525)
        entity_index.append([e11,e12])
        entity_mention.append(sample['tokens'][e11+1:e12])
        entity_index.append([e21,e22])
        entity_mention.append(sample['tokens'][e21+1:e22])

    data_loader = get_data_loader(config, all_data, shuffle=False, drop_last=False, batch_size=1)
    features = []
    encoder.eval()
    for step, batch_data in enumerate(data_loader):
        labels, _, tokens = batch_data
        tokens = torch.stack([x.to(config.device) for x in tokens],dim=0)
        with torch.no_grad():
            feature = encoder(tokens)
        feature1, feature2 = torch.split(feature, [config.encoder_output_size,config.encoder_output_size], dim=1)
        features.append(feature1)
        features.append(feature2)
    features = torch.cat(features, dim=0)
    # similarity_matrix = F.cosine_similarity(features.unsqueeze(1), features.unsqueeze(0), dim=-1)
    similarity_matrix = []
    for i in range(len(features)):
        similarity_matrix.append([0]*len(features))

    for i in range(len(features)):
        for j in range(i,len(features)):
            similarity = F.cosine_similarity(features[i],features[j],dim=0)
            similarity_matrix[i][j] = similarity
            similarity_matrix[j][i] = similarity

    similarity_matrix = torch.tensor(similarity_matrix).to(config.device)
    zero = torch.zeros_like(similarity_matrix).to(config.device)
    diag = torch.diag_embed(torch.diag(similarity_matrix))
    similarity_matrix -= diag
    similarity_matrix = torch.where(similarity_matrix<0.95, zero, similarity_matrix)
    nonzero_index = torch.nonzero(similarity_matrix)
    expanded_train_count = 0

    for origin, replace in nonzero_index:
        sample_index = int(origin/2)
        sample = all_data[sample_index]
        if entity_mention[origin] == entity_mention[replace]:
            continue
        new_tokens = sample['tokens'][:entity_index[origin][0]+1] + entity_mention[replace] + sample['tokens'][entity_index[origin][1]:]
        if len(new_tokens) < config.max_length:
            new_tokens = new_tokens + [0]*(config.max_length-len(new_tokens))
        else:
            new_tokens = new_tokens[:config.max_length]

        new_sample = {
            'relation': sample['relation'],
            'neg_labels': sample['neg_labels'],
            'tokens': new_tokens
        }
        if sample_index < len(train_data) and expanded_train_count < 5 * len(train_data):
            expanded_train_data.append(new_sample)
            expanded_train_count += 1
        else:
            expanded_prev_train_data.append(new_sample)
    return expanded_train_data, expanded_prev_train_data


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class args:
    task = "fewrel"
    shot = 5
    config = 'config.ini'
    


In [3]:
config = Config(args.config)
config.device = torch.device(config.device)
config.n_gpu = torch.cuda.device_count()
config.batch_size_per_step = int(config.batch_size / config.gradient_accumulation_steps)

config.task = args.task
config.shot = args.shot
config.step1_epochs = 5
config.step2_epochs = 15
config.step3_epochs = 20
config.temperature = 0.08

if config.task == "FewRel":
    config.relation_file = "data/fewrel/relation_name.txt"
    config.rel_index = "data/fewrel/rel_index.npy"
    config.rel_feature = "data/fewrel/rel_feature.npy"
    config.rel_des_file = "data/fewrel/relation_description.txt"
    config.num_of_relation = 80
    if config.shot == 5:
        config.rel_cluster_label = "data/fewrel/CFRLdata_10_100_10_5/rel_cluster_label_0.npy"
        config.training_file = "data/fewrel/CFRLdata_10_100_10_5/train_0.txt"
        config.valid_file = "data/fewrel/CFRLdata_10_100_10_5/valid_0.txt"
        config.test_file = "data/fewrel/CFRLdata_10_100_10_5/test_0.txt"
    elif config.shot == 10:
        config.rel_cluster_label = "data/fewrel/CFRLdata_10_100_10_10/rel_cluster_label_0.npy"
        config.training_file = "data/fewrel/CFRLdata_10_100_10_10/train_0.txt"
        config.valid_file = "data/fewrel/CFRLdata_10_100_10_10/valid_0.txt"
        config.test_file = "data/fewrel/CFRLdata_10_100_10_10/test_0.txt"
    else:
        config.rel_cluster_label = "data/fewrel/CFRLdata_10_100_10_2/rel_cluster_label_0.npy"
        config.training_file = "data/fewrel/CFRLdata_10_100_10_2/train_0.txt"
        config.valid_file = "data/fewrel/CFRLdata_10_100_10_2/valid_0.txt"
        config.test_file = "data/fewrel/CFRLdata_10_100_10_2/test_0.txt"
else:
    config.relation_file = "data/tacred/relation_name.txt"
    config.rel_index = "data/tacred/rel_index.npy"
    config.rel_feature = "data/tacred/rel_feature.npy"
    config.num_of_relation = 41
    if config.shot == 5:
        config.rel_cluster_label = "data/tacred/CFRLdata_10_100_10_5/rel_cluster_label_0.npy"
        config.training_file = "data/tacred/CFRLdata_10_100_10_5/train_0.txt"
        config.valid_file = "data/tacred/CFRLdata_10_100_10_5/valid_0.txt"
        config.test_file = "data/tacred/CFRLdata_10_100_10_5/test_0.txt"
    else:
        config.rel_cluster_label = "data/tacred/CFRLdata_10_100_10_10/rel_cluster_label_0.npy"
        config.training_file = "data/tacred/CFRLdata_10_100_10_10/train_0.txt"
        config.valid_file = "data/tacred/CFRLdata_10_100_10_10/valid_0.txt"
        config.test_file = "data/tacred/CFRLdata_10_100_10_10/test_0.txt"


In [4]:
sampler = data_sampler(config=config, seed=config.seed+100)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


[7 6 5 2 1 0 3 4]


In [5]:
data  = []
for steps, (training_data, valid_data, test_data, current_relations, historic_test_data, seen_relations) in enumerate(sampler):
            print(current_relations)
            data.append((training_data, valid_data, test_data, current_relations, historic_test_data, seen_relations))    

['person countries of residence', 'organization top members employees', 'organization member of', 'person origin', 'person title', 'organization country of headquarters']
['person stateorprovinces of residence', 'person date of death', 'organization number of employees members', 'person alternate names', 'person spouse']
['person date of birth', 'person stateorprovince of birth', 'person parents', 'person employee of', 'person stateorprovince of death']
['person cities of residence', 'person schools attended', 'person country of death', 'person children', 'person charges']
['organization subsidiaries', 'organization parents', 'organization alternate names', 'organization city of headquarters', 'person siblings']
['person country of birth', 'organization website', 'organization shareholders', 'organization dissolved', 'organization founded by']
['person cause of death', 'organization political religious affiliation', 'organization stateorprovince of headquarters', 'person other family',

In [6]:
train_data = []

In [7]:
for j in range(len(data)):
    train_data = []
    for i in range(len(list(data[j][0].values()))):
        train_data.extend(list(data[j][0].values())[i])
    cnt = 0
    for x in train_data:
        if 30524 not in x['tokens']:
            cnt +=1
            print(x['tokens'])
    if cnt > 0:
        print(j)
        print(cnt)

[101, 2268, 1011, 5757, 1011, 6021, 2102, 2692, 2549, 1024, 5354, 1024, 2184, 4012, 1028, 2626, 1024, 9587, 14945, 25353, 10936, 3669, 25886, 1026, 25353, 10936, 1030, 20643, 9006, 1028, 8299, 1024, 1013, 1013, 7479, 29337, 28251, 8586, 5358, 1013, 3422, 1029, 1058, 1027, 1053, 2063, 2620, 4328, 2629, 2497, 2860, 2683, 16409, 9587, 14945, 25353, 10936, 3669, 25886, 1026, 25353, 10936, 1030, 20643, 9006, 1028, 8299, 1024, 1013, 1013, 2739, 15396, 5643, 9006, 1013, 2739, 1013, 4021, 5643, 1003, 1016, 24700, 7974, 2015, 1013, 6027, 1013, 2466, 1013, 17350, 23809, 2100, 28332, 4274, 2003, 6179, 2005, 2437, 1996, 3606, 2124, 1010, 2758, 2852, 24404, 15222, 2099, 1036, 1036, 15490, 17761, 1010, 2238, 1015, 1011, 1048, 15185, 1011, 30522, 16595, 8067, 30523, 1011, 25269, 2497, 1011, 1011, 102]
5
1


SyntaxError: invalid syntax (942630328.py, line 4)

In [67]:
import numpy as np
np.argwhere(tokens == 30524).size

0

In [53]:
tokens = np.array(tokens)

In [49]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(config.bert_path, additional_special_tokens=["[E11]", "[E12]", "[E21]", "[E22]"])
tokens = [101, 2268, 1011, 5757, 1011, 6021, 2102, 2692, 2549, 1024, 5354, 1024, 2184, 4012, 1028, 2626, 1024, 9587, 14945, 25353, 10936, 3669, 25886, 1026, 25353, 10936, 1030, 20643, 9006, 1028, 8299, 1024, 1013, 1013, 7479, 29337, 28251, 8586, 5358, 1013, 3422, 1029, 1058, 1027, 1053, 2063, 2620, 4328, 2629, 2497, 2860, 2683, 16409, 9587, 14945, 25353, 10936, 3669, 25886, 1026, 25353, 10936, 1030, 20643, 9006, 1028, 8299, 1024, 1013, 1013, 2739, 15396, 5643, 9006, 1013, 2739, 1013, 4021, 5643, 1003, 1016, 24700, 7974, 2015, 1013, 6027, 1013, 2466, 1013, 17350, 23809, 2100, 28332, 4274, 2003, 6179, 2005, 2437, 1996, 3606, 2124, 1010, 2758, 2852, 24404, 15222, 2099, 1036, 1036, 15490, 17761, 1010, 2238, 1015, 1011, 1048, 15185, 1011, 30522, 16595, 8067, 30523, 1011, 25269, 2497, 1011, 1011, 102]
print(tokenizer.convert_ids_to_tokens(tokens))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


['[CLS]', '2009', '-', '06', '-', '03', '##t', '##0', '##4', ':', '59', ':', '10', 'com', '>', 'wrote', ':', 'mo', '##hd', 'sy', '##az', '##li', 'mahmud', '<', 'sy', '##az', '@', 'yahoo', '##com', '>', 'http', ':', '/', '/', 'www', '##you', '##tub', '##ec', '##om', '/', 'watch', '?', 'v', '=', 'q', '##e', '##8', '##mi', '##5', '##b', '##w', '##9', '##dc', 'mo', '##hd', 'sy', '##az', '##li', 'mahmud', '<', 'sy', '##az', '@', 'yahoo', '##com', '>', 'http', ':', '/', '/', 'news', '##asia', '##one', '##com', '/', 'news', '/', 'asia', '##one', '%', '2', '##bn', '##ew', '##s', '/', 'malaysia', '/', 'story', '/', 'a1', '##stor', '##y', '##200', 'internet', 'is', 'useful', 'for', 'making', 'the', 'truth', 'known', ',', 'says', 'dr', 'maha', '##thi', '##r', '`', '`', 'kuala', 'lumpur', ',', 'june', '1', '-', 'l', '##rb', '-', '[E11]', 'bern', '##ama', '[E12]', '-', 'rr', '##b', '-', '-', '[SEP]']


In [70]:
tokenizer.decode(tokens)

'[CLS] 2009 - 06 - 03t04 : 59 : 10 com > wrote : mohd syazli mahmud < syaz @ yahoocom > http : / / wwwyoutubecom / watch? v = qe8mi5bw9dc mohd syazli mahmud < syaz @ yahoocom > http : / / newsasiaonecom / news / asiaone % 2bnews / malaysia / story / a1story200 internet is useful for making the truth known, says dr mahathir ` ` kuala lumpur, june 1 - lrb - [E11] bernama [E12] - rrb - - [SEP]'

In [42]:
cnt = 0
for x in train_data:
    if 30524 not in x['tokens']:
       cnt += 1 

In [43]:
cnt

0

In [None]:
train_data[1]

In [72]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(config.bert_path, additional_special_tokens=["[E11]", "[E12]", "[E21]", "[E22]"])
tokenizer.convert_tokens_to_ids("[MASK]")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


103

In [98]:
from transformers import BertForMaskedLM
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
text = ["[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [MASK] created the Muppets . [SEP]", "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [MASK] created the Muppets . [SEP]"]

input_ids = tokenizer.encode(text, return_tensors="pt")
mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)
output = model(input_ids, return_dict=True)
logits = output.logits


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [102]:
input_ids

tensor([[101, 100, 100, 102]])

In [100]:
tokens = input_ids.tolist()
mask_idx = np.argwhere(np.array(tokens) == 103)[0][0]

IndexError: index 0 is out of bounds for axis 0 with size 0

In [96]:
mask_output = []

for i in range(input_ids.shape[0]):
    instance_output = torch.index_select(logits, 0, torch.tensor(i))
    instance_output = torch.index_select(instance_output, 1, torch.tensor(mask_idx))
    mask_output.append(instance_output)
mask_output = torch.cat(mask_output, dim=0)
    
    

In [97]:
mask_output.shape

torch.Size([1, 1, 30522])

In [8]:
import model.bert_encoder
from model.bert_encoder import Bert_EncoderMLM
config.pattern = "entity_marker_mask"
model = Bert_EncoderMLM(config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


AssertionError: Torch not compiled with CUDA enabled

In [5]:
from config import Config
config = Config('config.ini')

In [6]:
config.infonce_temprature

AttributeError: 'Config' object has no attribute 'infonce_temprature'

In [2]:
additional_special_tokens = ["[E11]", "[E12]", "[E21]", "[E22]"]
additional_special_tokens.extend([f"[REL{i}]" for i in range(1, 50 + 1)])

In [7]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', additional_special_tokens=additional_special_tokens)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
tokenizer.convert_tokens_to_ids("[REL1]")

30526

In [4]:
import torch
import torch.nn as nn
m = nn.Softmax(dim=1)
input = torch.randn(2, 3) + 1e8
output = m(input)

In [5]:
output

tensor([[0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333]])

In [8]:
V = torch.rand(1,3000)
C = torch.rand(5,768)
W = torch.rand(3000,768)
output = torch.matmul(V,W)
output = torch.matmul(output,C.T)
output.shape


torch.Size([1, 5])

In [14]:
a = torch.rand(1,1)
b = torch.rand(1,5)
temp = torch.cat([a,b],dim=1).squeeze()

In [17]:
m = nn.Softmax(dim=0)
output = m(temp)

In [18]:
output

tensor([0.1380, 0.1012, 0.2047, 0.1749, 0.1983, 0.1830])

In [19]:
a = torch.tensor(1.0)
b = torch.rand(3)

In [24]:
torch.cat([a.unsqueeze(0),b],dim=0)

tensor([1.0000, 0.0674, 0.6470, 0.4542])

In [31]:
h = torch.tensor([61.01224136352539 ,-23.5942, 162.7764,  70.3629, -92.8090,  34.2552])
m(h/abs(h).max())

tensor([0.1736, 0.1032, 0.3244, 0.1839, 0.0675, 0.1473])

In [30]:
h/abs(h).max()

tensor([ 0.3748, -0.1449,  1.0000,  0.4323, -0.5702,  0.2104])

In [29]:
abs(h)

tensor([ 61.0122,  23.5942, 162.7764,  70.3629,  92.8090,  34.2552])

In [1]:
from transformers import BertForMaskedLM
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

  from .autonotebook import tqdm as notebook_tqdm
Downloading: 100%|██████████| 420M/420M [00:40<00:00, 10.8MB/s] 
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
for n,p in model.named_parameters():
    if "cls" in n:
        print(n)

cls.predictions.bias
cls.predictions.transform.dense.weight
cls.predictions.transform.dense.bias
cls.predictions.transform.LayerNorm.weight
cls.predictions.transform.LayerNorm.bias


In [6]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [None]:

# FREEZE LM HEAD 
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import sys
import json
import gc
from tqdm import tqdm
from sklearn.cluster import KMeans
from encode import BERTMLMSentenceEncoderPrompt
from dataprocess import data_sampler_bert_prompt_deal_first_task_sckd
from model import proto_softmax_layer_bertmlm_prompt
from dataprocess import get_data_loader_bert_prompt
from util import set_seed
import wandb
import argparse

wandb.login(
    anonymous = 'allow',
    relogin = True,
    key = '9e33535aa8c9fcaa7fc1dfa97a70d9de5107ad37'
)

def eval_model(config, basemodel, test_set, mem_relations,seen_relations_ids):
    basemodel.eval()

    test_dataloader = get_data_loader_bert_prompt(config, test_set, shuffle=False, batch_size=30)
    allnum= 0.0
    correctnum = 0.0
    with torch.no_grad():
        for step, (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid, rawtext, lengths,
                typelabels, masks, mask_pos) in enumerate(test_dataloader):

            sentences = sentences.to(config['device'])
            masks = masks.to(config['device'])
            mask_pos = mask_pos.to(config['device'])
            logits, rep , _ = basemodel(sentences, masks, mask_pos)

            distances = basemodel.get_mem_feature(rep)
            short_logits = distances

            
            for index, logit in enumerate(logits):
                score = short_logits[index]  # logits[index] + short_logits[index] + long_logits[index]
                allnum += 1.0

                golden_score = score[labels[index]]
                max_neg_score = -2147483647.0
                for i in seen_relations_ids :
                    if (i != labels[index]) and score[i] > max_neg_score:
                        max_neg_score = score[i]
                if golden_score >= max_neg_score:
                    correctnum += 1

    acc = correctnum / allnum
    basemodel.train()
    return acc

def get_memory(config, model, proto_set):
    memset = []
    resset = []
    rangeset= [0]
    for i in proto_set:
        memset += i
        rangeset.append(rangeset[-1] + len(i))
    data_loader = get_data_loader_bert_prompt(config, memset, False, False)
    features = []
    for step, (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid, rawtext, lengths,
               typelabels, masks, mask_pos) in enumerate(data_loader):
        sentences = sentences.to(config['device'])
        masks = masks.to(config['device'])
        mask_pos = mask_pos.to(config['device'])
        feature = model.get_feature(sentences, masks, mask_pos)
        features.append(feature)
    features = np.concatenate(features)

    protos = []
    for i in range(len(proto_set)):
        protos.append(torch.tensor(features[rangeset[i]:rangeset[i+1],:].mean(0, keepdims = True)))
    protos = torch.cat(protos, 0)
    return protos

def select_data(mem_set, proto_memory, config, model, divide_train_set, num_sel_data, current_relations, selecttype):
    ####select data according to selecttype
    #selecttype is 0: cluster for every rel
    #selecttype is 1: use ave embedding
    rela_num = len(current_relations)
    for i in range(0, rela_num):
        thisrel = current_relations[i]
        if thisrel in mem_set.keys():
            #print("have set mem before")
            mem_set[thisrel] = {'0': [], '1': {'h': [], 't': []}}
            proto_memory[thisrel] = []
        else:
            mem_set[thisrel] = {'0': [], '1': {'h': [], 't': []}}
        thisdataset = divide_train_set[thisrel]
        data_loader = get_data_loader_bert_prompt(config, thisdataset, False, False)
        features = []
        for step, (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid, rawtext, lengths,
                typelabels, masks, mask_pos) in enumerate(data_loader):
            sentences = sentences.to(config['device'])
            masks = masks.to(config['device'])
            mask_pos = mask_pos.to(config['device'])
            feature = model.get_feature(sentences, masks, mask_pos)
            features.append(feature)
        features = np.concatenate(features)
        num_clusters = min(num_sel_data, len(thisdataset))
        if selecttype == 0:
            kmeans = KMeans(n_clusters=num_clusters, random_state=0)
            distances = kmeans.fit_transform(features)
            for i in range(num_clusters):
                sel_index = np.argmin(distances[:, i])
                instance = thisdataset[sel_index]
                ###change tylelabel
                instance[11] = 3
                ###add to mem data
                mem_set[thisrel]['0'].append(instance)  ####positive sample
                cluster_center = kmeans.cluster_centers_[i]
                proto_memory[thisrel].append(instance)
        elif selecttype == 1:
            #print("use average embedding")
            samplenum = features.shape[0]
            veclength = features.shape[1]
            sumvec = np.zeros(veclength)
            for j in range(samplenum):
                sumvec += features[j]
            sumvec /= samplenum

            ###find nearest sample
            mindist = 100000000
            minindex = -100
            for j in range(samplenum):
                dist = np.sqrt(np.sum(np.square(features[j] - sumvec)))
                if dist < mindist:
                    minindex = j
                    mindist = dist
            #print(minindex)
            instance = thisdataset[j]
            ###change tylelabel
            instance[11] = 3
            mem_set[thisrel]['0'].append(instance)
            proto_memory[thisrel].append(instance)
        else:
            print("error select type")
    #####to get negative sample  mem_set[thisrel]['1']
    if rela_num > 1:
        ####we need to sample negative samples
        allnegres = {}
        for i in range(rela_num):
            thisnegres = {'h':[],'t':[]}
            currel = current_relations[i]
            thisrelposnum = len(mem_set[currel]['0'])
            #assert thisrelposnum == num_sel_data
            #allnum = list(range(thisrelposnum))
            for j in range(thisrelposnum):
                thisnegres['h'].append(mem_set[currel]['0'][j][3])
                thisnegres['t'].append(mem_set[currel]['0'][j][5])
            allnegres[currel] = thisnegres
        ####get neg sample
        for i in range(rela_num):
            togetnegindex = (i + 1) % rela_num
            togetnegrelname = current_relations[togetnegindex]
            mem_set[current_relations[i]]['1']['h'].extend(allnegres[togetnegrelname]['h'])
            mem_set[current_relations[i]]['1']['t'].extend(allnegres[togetnegrelname]['t'])
    return mem_set

def select_data_all(mem_set, proto_memory, config, model, divide_train_set, num_sel_data, current_relations, selecttype):
    ####select data according to selecttype
    #selecttype is 0: cluster for every rel
    #selecttype is 1: use ave embedding
    rela_num = len(current_relations)
    for i in range(0, rela_num):
        thisrel = current_relations[i]
        if thisrel in mem_set.keys():
            #print("have set mem before")
            mem_set[thisrel] = {'0': [], '1': {'h': [], 't': []}}
            proto_memory[thisrel].pop()
        else:
            mem_set[thisrel] = {'0': [], '1': {'h': [], 't': []}}
        thisdataset = divide_train_set[thisrel]
        # print(len(thisdataset))
        for i in range(len(thisdataset)):
            instance = thisdataset[i]
            ###change tylelabel
            instance[11] = 3
            ###add to mem data
            mem_set[thisrel]['0'].append(instance)
            proto_memory[thisrel].append(instance)
    if rela_num > 1:
        ####we need to sample negative samples
        allnegres = {}
        for i in range(rela_num):
            thisnegres = {'h':[],'t':[]}
            currel = current_relations[i]
            thisrelposnum = len(mem_set[currel]['0'])
            #assert thisrelposnum == num_sel_data
            #allnum = list(range(thisrelposnum))
            for j in range(thisrelposnum):
                thisnegres['h'].append(mem_set[currel]['0'][j][3])
                thisnegres['t'].append(mem_set[currel]['0'][j][5])
            allnegres[currel] = thisnegres
        ####get neg sample
        for i in range(rela_num):
            togetnegindex = (i + 1) % rela_num
            togetnegrelname = current_relations[togetnegindex]
            mem_set[current_relations[i]]['1']['h'].extend(allnegres[togetnegrelname]['h'])
            mem_set[current_relations[i]]['1']['t'].extend(allnegres[togetnegrelname]['t'])
    return mem_set

def train_model_with_hard_neg(config, model, mem_set, traindata, epochs, current_proto, seen_relation_ids, tokenizer, ifnegtive=0, threshold=0.2, use_loss5=True, only_mem=False):
    print('training data num: ' + str(len(traindata)))
    mem_data = []
    if len(mem_set) != 0:
        for key in mem_set.keys():
            mem_data.extend(mem_set[key]['0'])
    print('memory data num: '+ str(len(mem_data)))
    if only_mem==True:
        train_set = mem_data
    else:
        train_set = traindata + mem_data
    print('all train data: ' + str(len(train_set)))
    data_loader = get_data_loader_bert_prompt(config, train_set, batch_size=config['batch_size_per_step'])
    model.train()
    criterion = nn.CrossEntropyLoss()
    mseloss = nn.MSELoss()
    softmax = nn.Softmax(dim=0)
    lossfn = nn.MultiMarginLoss(margin=0.2)
    optimizer = optim.Adam(model.parameters(), config['learning_rate'])
    for epoch_i in range(epochs):
        model.set_memorized_prototypes_midproto(current_proto)
        losses1 = []
        losses2 = []
        losses3 = []
        losses4 = []
        losses5 = []
        losses6 = []

        lossesfactor1 = 0.0
        lossesfactor2 = 1.0
        lossesfactor3 = 1.0
        lossesfactor4 = 0.0
        if use_loss5 == True:
            lossesfactor5 = 1.0
        else:
            lossesfactor5 = 0.0
        lossesfactor6 = 0.0
        for step, (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid, rawtext, lengths,
            typelabels, masks, mask_pos) in enumerate(data_loader):
            model.zero_grad()
            labels = labels.to(config['device'])
            typelabels = typelabels.to(config['device'])  ####0:rel  1:pos(new train data)  2:neg  3:mem
            numofmem = 0
            numofnewtrain = 0
            allnum = 0
            memindex = []
            for index,onetype in enumerate(typelabels):
                if onetype == 1:
                    numofnewtrain += 1
                if onetype == 3:
                    numofmem += 1
                    memindex.append(index)
                allnum += 1

            sentences = sentences.to(config['device'])
            masks = masks.to(config['device'])
            mask_pos = mask_pos.to(config['device'])
            logits, rep , lmhead_output = model(sentences, masks, mask_pos)
            logits_proto = model.mem_forward(rep)

            loss1 = criterion(logits, labels)
            loss2 = criterion(logits_proto, labels)
            loss4 = lossfn(logits_proto, labels)
            loss3 = torch.tensor(0.0).to(config['device'])
            for index, logit in enumerate(logits):
                score = logits_proto[index]
                preindex = labels[index]
                maxscore = score[preindex]
                size = score.shape[0]
                maxsecondmax = [maxscore]
                secondmax = -100000
                for j in range(size):
                    if j != preindex and score[j] > secondmax:
                        secondmax = score[j]
                maxsecondmax.append(secondmax)
                for j in range(size):
                    if j != preindex and maxscore - score[j] < threshold:
                        maxsecondmax.append(score[j])
                maxsecond = torch.stack(maxsecondmax, 0)
                maxsecond = torch.unsqueeze(maxsecond, 0)
                la = torch.tensor([0]).to(config['device'])
                loss3 += criterion(maxsecond, la)
            loss3 /= logits.shape[0]
            
            loss5 = torch.tensor(0.0).to(config['device'])
            allusenum5 = 0
            # # --- test---
            # try:
            #     print("Prototype : ")
            #     print(model.prototypes.shape)
            # except:
            #     print("no model.prototypes")
            # # --- test---
                
            # --- add info_nce loss ---
            prototypes = model.prototypes.clone()
            infoNCE_loss = 0
            try:
                for i in range(rep.shape[0]):
                    neg_prototypes = [prototypes[rel_id] for rel_id in seen_relations_ids if rel_id != labels[i].item()]
                    neg_prototypes = torch.stack(neg_prototypes)
                    neg_prototypes.requires_grad_ = False
                    neg_prototypes = neg_prototypes.squeeze() # [num_neg_prototypes, dim]

                    f_pos = model.sentence_encoder.infoNCE_f(lmhead_output[i],rep[i] , temperature = config['infonce_temperature'])
                    f_neg = model.sentence_encoder.infoNCE_f(lmhead_output[i],neg_prototypes , temperature = config['infonce_temperature'])
                    f_concat = torch.cat([f_pos.unsqueeze(0),f_neg],dim = 0)

                    f_concat = torch.log(torch.max(f_concat , torch.tensor(1e-9).to(config.device)))

                    infoNCE_loss += -torch.log(softmax(f_concat)[0])
            except Exception as e:
                print(e.with_traceback())
                print("no infoNCE_loss")
            infoNCE_loss /= rep.shape[0]
            # --- add info_nce loss ---

            # --- add mlm loss ---
            mlm_labels = labels + 30522 - 1
            mlm_labels = mlm_labels.to(config['device'])
            mlm_labels.requires_grad = False
            mlm_loss = criterion(lmhead_output,mlm_labels)
            # --- add mlm loss ---
            for index in memindex:
                preindex = labels[index]
                if preindex in model.haveseenrelations:
                    loss5 += mseloss(softmax(rep[index]), softmax(model.prototypes[preindex]))
                allusenum5 += 1
            
            loss6 = torch.tensor(0.0).to(config['device'])
            allusenum6 = 0
            for index in memindex:
                preindex = labels[index]
                if preindex in model.haveseenrelations:
                    best_distrbution = model.mem_forward_update(rep[index].view(1, -1), model.bestproto)
                    current_distrbution = model.mem_forward_update(model.prototypes[preindex].view(1, -1), model.bestproto)
                    loss6 += mseloss(best_distrbution, current_distrbution)
                allusenum6 += 1
            
            if len(memindex) == 0:
                loss = loss1 * lossesfactor1 + loss2 * lossesfactor2 + loss3 * lossesfactor3 + loss4 * lossesfactor4
            else:
                loss5 = loss5 / allusenum5
                loss6 = loss6 / allusenum6
                loss = loss1 * lossesfactor1 + loss2 * lossesfactor2 + loss3 * lossesfactor3 + loss4 * lossesfactor4 + loss5 * lossesfactor5 + loss6 * lossesfactor6   + infoNCE_loss * config['infonce_lossfactor'] + mlm_loss * config['mlm_lossfactor']###with loss5
            loss.backward()
            losses1.append(loss1.item())
            losses2.append(loss2.item())
            losses3.append(loss3.item())
            losses4.append(loss4.item())
            losses5.append(loss5.item())
            losses6.append(loss6.item())
            # print(f" InfoNCE_loss: {infoNCE_loss.item()} , mlm_loss: {mlm_loss.item()}")
            wandb.log({"Loss1": loss1.item(), "Loss2": loss2.item(), "Loss3": loss3.item(), "Loss4": loss4.item(), "Loss5": loss5.item(), "Loss6": loss6.item(), "InfoNCE_loss": infoNCE_loss.item(), "mlm_loss": mlm_loss.item()})

            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])#cxd
            optimizer.step()
        return model

def train_memory(config, model, mem_set, train_set, epochs, current_proto, original_vocab_size,seen_relation_ids, ifusemem=True, threshold=0.2):
    train_set = []
    if ifusemem:
        mem_data = []
        if len(mem_set)!=0:
            for key in mem_set.keys():
                mem_data.extend(mem_set[key]['0'])
        train_set.extend(mem_data)
    data_loader = get_data_loader_bert_prompt(config, train_set, batch_size = config['batch_size_per_step'])
    model.train()
    criterion = nn.CrossEntropyLoss()
    mseloss = nn.MSELoss()
    softmax = nn.Softmax(dim=0)
    lossfn = nn.MultiMarginLoss(margin=0.2)
    optimizer = optim.Adam(model.parameters(), config['learning_rate'])#cxd
    for epoch_i in range(epochs):
        model.set_memorized_prototypes_midproto(current_proto)
        losses1 = []
        losses2 = []
        losses3 = []
        losses4 = []
        losses5 = []
        losses6 = []

        lossesfactor1 = 0.0
        lossesfactor2 = 1.0
        lossesfactor3 = 1.0
        lossesfactor4 = 0.0
        lossesfactor5 = 1.0
        lossesfactor6 = 1.0

        for step, (labels, neg_labels, sentences, firstent, firstentindex, secondent, secondentindex, headid, tailid, rawtext,
                   lengths, typelabels, masks, mask_pos) in enumerate(tqdm(data_loader)):
            model.zero_grad()
            sentences = sentences.to(config['device'])
            masks = masks.to(config['device'])
            mask_pos = mask_pos.to(config['device'])
            logits, rep , lmhead_output = model(sentences, masks, mask_pos)
            logits_proto = model.mem_forward(rep)

            labels = labels.to(config['device'])
            loss1 = criterion(logits, labels)
            loss2 = criterion(logits_proto, labels)
            loss4 = lossfn(logits_proto, labels)
            loss3 = torch.tensor(0.0).to(config['device'])
            ###add triple loss
            for index, logit in enumerate(logits):
                score = logits_proto[index]
                preindex = labels[index]
                maxscore = score[preindex]
                size = score.shape[0]
                maxsecondmax = [maxscore]
                secondmax = -100000
                for j in range(size):
                    if j != preindex and score[j] > secondmax:
                        secondmax = score[j]
                maxsecondmax.append(secondmax)
                for j in range(size):
                    if j != preindex and maxscore - score[j] < threshold:
                        maxsecondmax.append(score[j])
                maxsecond = torch.stack(maxsecondmax, 0)
                maxsecond = torch.unsqueeze(maxsecond, 0)
                la = torch.tensor([0]).to(config['device'])
                loss3 += criterion(maxsecond, la)
            loss3 /= logits.shape[0]

            loss5 = torch.tensor(0.0).to(config['device'])

            for index, logit in enumerate(logits):
                preindex = labels[index]
                if preindex in model.haveseenrelations:
                    loss5 += mseloss(softmax(rep[index]), softmax(model.prototypes[preindex]))
            loss5 /= logits.shape[0] 

            loss6 = torch.tensor(0.0).to(config['device'])
            for index, logit in enumerate(logits):
                preindex = labels[index]
                if preindex in model.haveseenrelations:
                    best_distrbution = model.mem_forward_update(rep[index].view(1, -1), model.bestproto)
                    current_distrbution = model.mem_forward_update(model.prototypes[preindex].view(1, -1), model.bestproto)
                    loss6 += mseloss(best_distrbution, current_distrbution)
            loss6 /= logits.shape[0]
            
             # --- add info_nce loss ---
            prototypes = model.prototypes.clone()
            infoNCE_loss = 0
            try:
                for i in range(rep.shape[0]):
                    neg_prototypes = [prototypes[rel_id] for rel_id in seen_relations_ids if rel_id != labels[i].item()]
                    neg_prototypes = torch.stack(neg_prototypes)
                    neg_prototypes.requires_grad_ = False
                    neg_prototypes = neg_prototypes.squeeze() # [num_neg_prototypes, dim]

                    f_pos = model.sentence_encoder.infoNCE_f(lmhead_output[i],rep[i] , temperature = config['infonce_temperature'])
                    f_neg = model.sentence_encoder.infoNCE_f(lmhead_output[i],neg_prototypes , temperature = config['infonce_temperature'])
                    f_concat = torch.cat([f_pos.unsqueeze(0),f_neg],dim = 0)

                    infoNCE_loss += -torch.log(softmax(f_concat/abs(f_concat).max())[0])
            except Exception as e:
                print(e.with_traceback())
                print("no infoNCE_loss")
            infoNCE_loss /= rep.shape[0]

            # --- add info_nce loss ---

            # --- add mlm loss ---
            mlm_labels = labels + 30522 - 1
            mlm_labels = mlm_labels.to(config['device'])
            mlm_loss = criterion(lmhead_output,mlm_labels)
            # --- add mlm loss ---



            loss = loss1 * lossesfactor1 + loss2 * lossesfactor2 + loss3 * lossesfactor3 + loss4 * lossesfactor4  + loss5 * lossesfactor5 + loss6 * lossesfactor6 + infoNCE_loss * config['infonce_lossfactor'] + mlm_loss * config['mlm_lossfactor']
            loss.backward()
            losses1.append(loss1.item())
            losses2.append(loss2.item())
            losses3.append(loss3.item())
            losses4.append(loss4.item())
            losses5.append(loss5.item())
            losses6.append(loss6.item())
            wandb.log({"Loss1": loss1.item(), "Loss2": loss2.item(), "Loss3": loss3.item(), "Loss4": loss4.item(), "Loss5": loss5.item(), "Loss6": loss6.item(), "InfoNCE_loss": infoNCE_loss.item(), "mlm_loss": mlm_loss.item()})
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])#cxd
            optimizer.step()
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--task" , default="tacred" , type=str)
    parser.add_argument("--shot" , default=5 , type=int)
    args = parser.parse_args()

    if args.task == "tacred":
        f = open("config/config_tacred.json", "r")
    elif args.task == "fewrel":
        f = open("config/config_fewrel_5and10.json", "r")
    else:
        raise ValueError("task must be tacred or fewrel")
    
    config = json.loads(f.read())
    f.close()

    if args.task == "fewrel":
        config['relation_file'] = "data/fewrel/relation_name.txt"
        config['rel_index'] = "data/fewrel/rel_index.npy"
        config['rel_feature'] = "data/fewrel/rel_feature.npy"
        config['rel_des_file'] = "data/fewrel/relation_description.txt"
        config['num_of_relation'] = 80
        if args.shot == 5:
            print('fewrel 5 shot')
            config['rel_cluster_label'] = "data/fewrel/CFRLdata_10_100_10_5/rel_cluster_label_0.npy"
            config['training_file'] = "data/fewrel/CFRLdata_10_100_10_5/train_0.txt"
            config['valid_file'] = "data/fewrel/CFRLdata_10_100_10_5/valid_0.txt"
            config['test_file'] = "data/fewrel/CFRLdata_10_100_10_5/test_0.txt"
        elif args.shot == 10:
            config['rel_cluster_label'] = "data/fewrel/CFRLdata_10_100_10_10/rel_cluster_label_0.npy"
            config['training_file'] = "data/fewrel/CFRLdata_10_100_10_10/train_0.txt"
            config['valid_file'] = "data/fewrel/CFRLdata_10_100_10_10/valid_0.txt"
            config['test_file'] = "data/fewrel/CFRLdata_10_100_10_10/test_0.txt"
        else:
            print('fewrel 2 shot')
            config['rel_cluster_label'] = "data/fewrel/CFRLdata_10_100_10_2/rel_cluster_label_0.npy"
            config['training_file'] = "data/fewrel/CFRLdata_10_100_10_2/train_0.txt"
            config['valid_file'] = "data/fewrel/CFRLdata_10_100_10_2/valid_0.txt"
            config['test_file'] = "data/fewrel/CFRLdata_10_100_10_2/test_0.txt"
    else:
        config['relation_file'] = "data/tacred/relation_name.txt"
        config['rel_index'] = "data/tacred/rel_index.npy"
        config['rel_feature'] = "data/tacred/rel_feature.npy"
        config['num_of_relation'] = 41
        if args.shot == 5:
            config['rel_cluster_label'] = "data/tacred/CFRLdata_10_100_10_5/rel_cluster_label_0.npy"
            config['training_file'] = "data/tacred/CFRLdata_10_100_10_5/train_0.txt"
            config['valid_file'] = "data/tacred/CFRLdata_10_100_10_5/valid_0.txt"
            config['test_file'] = "data/tacred/CFRLdata_10_100_10_5/test_0.txt"
        else:
            config['rel_cluster_label'] = "data/tacred/CFRLdata_10_100_10_10/rel_cluster_label_0.npy"
            config['training_file'] = "data/tacred/CFRLdata_10_100_10_10/train_0.txt"
            config['valid_file'] = "data/tacred/CFRLdata_10_100_10_10/valid_0.txt"
            config['test_file'] = "data/tacred/CFRLdata_10_100_10_10/test_0.txt"

        
    config['device'] = torch.device('cuda' if torch.cuda.is_available() and config['use_gpu'] else 'cpu')
    config['n_gpu'] = torch.cuda.device_count()
    config['batch_size_per_step'] = int(config['batch_size'] / config["gradient_accumulation_steps"])
    config['neg_sampling'] = False


    config['device'] = torch.device('cuda' if torch.cuda.is_available() and config['use_gpu'] else 'cpu')
    config['n_gpu'] = torch.cuda.device_count()
    config['batch_size_per_step'] = int(config['batch_size'] / config["gradient_accumulation_steps"])
    config['neg_sampling'] = False

    config['first_task_k-way'] = 10
    config['k-shot'] = 5
    donum = 1
    epochs = 1
    threshold=0.1

    wandb.init(
    project = 'DATN',
    name = f"ConPL_{args.task}_{args.shot}_{config['infonce_lossfactor']}_{config['mlm_lossfactor']}",
    config = {
        'name': "ConPL",
        "task" : args.task,
        "shot" : "5",
        "infonce_lossfactor" : config['infonce_lossfactor'],
        "mlm_lossfactor" : config['mlm_lossfactor']
    }
)

    for m in range(donum):
        print(m)

        encoderforbase = BERTMLMSentenceEncoderPrompt(config)
        
        #freeze lm head --------------
        for name,param in encoderforbase.bert.named_parameters():
            if 'cls' in name:
                param.requires_grad = False
        #freeze lm head --------------

        original_vocab_size = len(list(encoderforbase.tokenizer.get_vocab()))
        print('Vocab size: %d'%original_vocab_size)
        if config["prompt"] == "hard-complex":
            template = 'the relation between e1 and e2 is mask . '
            print('Template: %s'%template)
        elif config["prompt"] == "hard-simple":
            template = 'e1 mask e2 . '
            print('Template: %s'%template)
        else:
            template = None
            print("no use soft prompt.")
        
        sampler = data_sampler_bert_prompt_deal_first_task_sckd(config, encoderforbase.tokenizer, template)
        modelforbase = proto_softmax_layer_bertmlm_prompt(encoderforbase, num_class=len(sampler.id2rel), id2rel=sampler.id2rel, drop=0, config=config)
        modelforbase = modelforbase.to(config["device"])

        sequence_results = []
        sequence_results_average = []
        result_whole_test = []
        result_whole_test_average = []
        all_allresults_array = []

        fr_all = []
        distored_all = []
        for rou in range(6): #6 times different seeds to get average results

            num_class = len(sampler.id2rel)
            print('random_seed: ' + str(config['random_seed'] + 100 * rou))
            set_seed(config, config['random_seed'] + 100 * rou)
            sampler.set_seed(config['random_seed'] + 100 * rou)

            #cxd
            proto_acc = [[] for i in range(num_class)]
            proto_embedding = [[] for i in range(num_class)]

            mem_set = {} ####  mem_set = {rel_id:{'0':[positive samples],'1':[negative samples]}} 换5个head 换5个tail
            mem_relations = []   ###not include relation of current task

            past_relations = []

            savetest_all_data = None
            saveseen_relations = []

            proto_memory = []

            for i in range(len(sampler.id2rel)):
                proto_memory.append([sampler.id2rel_pattern[i]])
            # print('proto_memory', proto_memory)
            oneseqres = []
            whole_acc = []
            allresults_list = []
            ##################################
            whichdataselecct = 1
            ifnorm = True
            ##################################
            id2rel = sampler.id2rel
            rel2id = sampler.rel2id
            seen_test_data_by_task = []
            for steps, (training_data, valid_data, test_data,test_all_data, seen_relations,current_relations) in enumerate(sampler):
                print('current training data num: ' + str(len(training_data)))
                seen_relations_ids = [rel2id[relation] + 1 for relation in seen_relations] # seen relation (list of int) (include relation of current task)
                current_relations_ids = [rel2id[relation] + 1 for relation in current_relations] # current relation (list of int)
                
                seen_test_data_by_task.append(test_data)
                savetest_all_data = [] # test data of all tasks (array of shape 8000x16)
                for tmp in test_all_data:
                    savetest_all_data.extend(tmp)
                saveseen_relations = seen_relations

                currentnumber = len(current_relations)
                print('current relations num: '+ str(currentnumber))
                divide_train_set = {}
                for relation in current_relations_ids:
                    divide_train_set[relation] = []  ##int
                for data in training_data:
                    divide_train_set[data[0]].append(data)
                print('current divide num: '+ str(len(divide_train_set)))


                current_proto = get_memory(config, modelforbase, proto_memory) #这时候的current_proto是根据81个关系的名称输入模型之中得到的81个fake embedding：[81, 200]
                select_data_all(mem_set, proto_memory, config, modelforbase, divide_train_set,
                            config['rel_memory_size'], current_relations_ids, 0)  ##config['rel_memory_size'] == 1 
                            #proto_memory中的样本根据divide_train_set(training_data划分对应类)来增加每个类对应K个样本，mem_set[thisrel] = {'0': [], '1': {'h': [], 't': []}} 0放正样例，1放负样例，datatype=3

                ###add to mem data
                mem_set_length = {}
                proto_memory_length = []
                for i in range(len(proto_memory)):
                    proto_memory_length.append(len(proto_memory[i]))
                for key in mem_set.keys():
                    mem_set_length[key] = len(mem_set[key]['0'])
                print("mem_set_length", mem_set_length)
                print("proto_memory_length", proto_memory_length)

                for j in range(1):
                    current_proto = get_memory(config, modelforbase, proto_memory)
                    modelforbase = train_model_with_hard_neg(config, modelforbase, mem_set, training_data, epochs,
                                                                current_proto, tokenizer = encoderforbase.tokenizer,seen_relation_ids=seen_relations_ids, ifnegtive=0,threshold=threshold, use_loss5=False)
                
                select_data(mem_set, proto_memory, config, modelforbase, divide_train_set,
                            config['rel_memory_size'], current_relations_ids, 0)  ##config['rel_memory_size'] == 1 
                
                mem_set_length = {}
                proto_memory_length = []
                for i in range(len(proto_memory)):
                    proto_memory_length.append(len(proto_memory[i]))
                for key in mem_set.keys():
                    mem_set_length[key] = len(mem_set[key]['0'])
                print("mem_set_length", mem_set_length)
                print("proto_memory_length", proto_memory_length)
                for j in range(1):
                    current_proto = get_memory(config, modelforbase, proto_memory)
                    modelforbase = train_model_with_hard_neg(config, modelforbase, mem_set, training_data, epochs,
                                                                current_proto, tokenizer = encoderforbase.tokenizer,seen_relation_ids=seen_relations_ids, ifnegtive=0,threshold=threshold)
                
                #add train memory
                current_proto = get_memory(config, modelforbase, proto_memory)
                modelforbase = train_memory(config, modelforbase, mem_set, training_data, epochs*3, current_proto, original_vocab_size,ifusemem= True,seen_relation_ids=seen_relations_ids, threshold=threshold)

                
                current_proto = get_memory(config, modelforbase, proto_memory)
                modelforbase.set_memorized_prototypes_midproto(current_proto)
                modelforbase.save_bestproto(current_relations_ids)#save bestproto
                mem_relations.extend(current_relations_ids)

                currentalltest = []
                for mm in range(len(test_data)):
                    currentalltest.extend(test_data[mm])

                #compute mean accuarcy
                results = [eval_model(config, modelforbase, item, mem_relations,seen_relations_ids) for item in seen_test_data_by_task] # results of all previous task + this task after training on current task
                allresults_list.append(results)
                results_average = np.array(results).mean() # average accuracy of all tasks after training on current task
                wandb.log({f"Round {rou} Average Accuracy": results_average})
                whole_acc.append(results_average)

                #compute whole accuarcy
                seen_test_set = []
                for seen_relation in seen_relations_ids:
                    seen_test_set.extend(test_all_data[seen_relation - 1]) # test_all_data is a list of test data of all relation (test_all_data[0] is test data of relation 1])
                thisstepres = eval_model(config, modelforbase, seen_test_set, mem_relations,seen_relations_ids) # combine all test data of all tasks and evaluate
                oneseqres.append(thisstepres)

            sequence_results.append(np.array(oneseqres)) # combine all test data of all tasks and evaluate
            sequence_results_average.append(np.array(whole_acc)) # evaluate each task and average

            allres = eval_model(config, modelforbase, savetest_all_data, saveseen_relations,seen_relations_ids) # eval on all test data of all tasks
            result_whole_test.append(allres)

            print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")
            print("after one epoch allres whole:\t" + str(allres))
            print(result_whole_test)

            allresults = [eval_model(config, modelforbase, item, num_class,seen_relations_ids) for item in seen_test_data_by_task]
            allresults_average = np.array(allresults).mean()
            result_whole_test_average.append(allresults_average)
            print("after one epoch allres average:\t" + str(allresults))
            print(result_whole_test_average)


            modelforbase = modelforbase.to('cpu')
            del modelforbase
            gc.collect()
            if config['device'] == 'cuda':
                torch.cuda.empty_cache()
            encoderforbase = BERTMLMSentenceEncoderPrompt(config)
            modelforbase = proto_softmax_layer_bertmlm_prompt(encoderforbase, num_class=len(sampler.id2rel), id2rel=sampler.id2rel, drop=0, config=config)
            modelforbase = modelforbase.to(config["device"])
        print("Final result: whole!")
        print(result_whole_test)
        for one in sequence_results:
            for item in one:
                sys.stdout.write('%.4f, ' % item)
            print('')
        avg_result_all_test = np.average(sequence_results, 0)
        for one in avg_result_all_test:
            sys.stdout.write('%.4f, ' % one)
        print('')
        print("Final result: average!")
        print(result_whole_test_average)
        for one in sequence_results_average:
            for item in one:
                sys.stdout.write('%.4f, ' % item)
            print('')
        avg_result_all_test_average = np.average(sequence_results_average, 0)
        for one in avg_result_all_test_average:
            sys.stdout.write('%.4f, ' % one)
        print('')
