In [1]:
import sys 

import re
import os

import tqdm
import json
import pickle

import shelve

import numpy as np
import collections
import itertools

from collections import Counter
import matplotlib.pyplot as plt

import paddle
from paddle.static import InputSpec

import time
import gc


In [2]:
# Index maps written by the preprocessing step (string -> integer id) and
# their inverses (integer id -> string).
with open('./work/word_to_ind.json') as f:
    word_to_ind = json.load(f)
with open('./work/rel_to_ind.json') as f:
    rel_to_ind = json.load(f)
with open('./work/type_to_ind.json') as f:
    type_to_ind = json.load(f)
ind_to_word = dict((idx, word) for word, idx in word_to_ind.items())
ind_to_rel = dict((idx, rel) for rel, idx in rel_to_ind.items())
ind_to_type = dict((idx, typ) for typ, idx in type_to_ind.items())

# distant_dict.json holds a list of entries; only the first two fields of each
# are kept, as a set of tuples for fast membership tests.
with open('./work/distant_dict.json') as f:
    distant_dict = {(entry[0], entry[1]) for entry in json.load(f)}

glove_file = r'data/data92342/glove.840B.300d.txt'

In [3]:
# ---- batching / label space ----
batch_size = 30                 # documents per training / evaluation batch
n_class = len(rel_to_ind)       # number of relation classes (rel_to_ind includes 'None')

# ---- embedding widths ----
no_components = 300             # word-vector width; must match the GloVe file
ner_emb = 30                    # entity-type embedding width
coref_emb = 30                  # coreference-id embedding width
dep_emb = 30                    # not referenced in the cells shown
sent_dist_emb = 30              # distance-bucket embedding width
dim_sent = 30                   # not referenced in the cells shown

# ---- sequence / sampling limits ----
max_sents = 21                  # rows of the distance-bucket embedding table
coref_maxlen = 60               # rows of the coreference embedding (entity ids must be < this)
sent_rel_max = 1800             # cap on sampled candidate pairs per document
token_max_len = 511             # documents must have fewer tokens than this
max_seq_len = 511               # per-document token stride used by the model

# ---- network ----
dim = [500, 250, 128, 105]      # dim[2]: LSTM units per direction, dim[3]: bilinear output
dim_2 = 60                      # not referenced in the cells shown
drop_out = 0.5

# ---- optimisation ----
learning_rate = 2e-4
epoch_n = 100

# ---- noisy-label handling ----
n_C = 0.1 # nearly clean samples selection parameter
n_V = 0.01 # correction proportion
n_R = 0.01 # relabeling proportion
t_w = 30 # warm up epoch




In [4]:
# Build (or load from cache) the word-embedding matrix. Row i holds the
# pretrained vector of the word whose id is i in word_to_ind. Words absent from
# the pretrained file keep a uniform random init in [-0.5, 0.5). Two rows
# beyond len(word_to_ind) are allocated (ids 0/1 are used as padding/unknown
# by dump_data).
np.random.seed(1)
if not os.path.exists('./work/glove_emb_mat.npy'):
    glove_emb_mat = (np.random.rand(len(word_to_ind) + 2, no_components) - 0.5)
    glove_ext = os.path.splitext(glove_file)[1].lstrip('.')
    if glove_ext == 'bin':
        from gensim.models import KeyedVectors

        # Load the configured path directly: gensim's `datapath()` would
        # resolve it inside gensim's bundled test-data directory instead.
        wv_from_bin = KeyedVectors.load_word2vec_format(glove_file, binary=True)
        for w, i in tqdm.tqdm(word_to_ind.items()):
            if w in wv_from_bin:
                glove_emb_mat[i, :] = wv_from_bin.get_vector(w)
        del wv_from_bin
        gc.collect()
        np.save('./work/glove_emb_mat.npy', glove_emb_mat)
    elif glove_ext == 'txt':
        with open(glove_file, 'r', encoding='utf8') as f:
            for line in tqdm.tqdm(f):
                # The vector is the LAST `no_components` fields. A few
                # glove.840B tokens contain spaces themselves (e.g. ". . ."),
                # so the word is everything before the vector.
                # Assumes no_components equals the file's vector width.
                fields = line.rstrip().split(' ')
                if len(fields) <= no_components:
                    continue  # header or malformed line
                w = ' '.join(fields[:-no_components])
                if w in word_to_ind:
                    glove_emb_mat[word_to_ind[w], :] = [float(v) for v in fields[-no_components:]]
        np.save('./work/glove_emb_mat.npy', glove_emb_mat)
else:
    glove_emb_mat = np.load('./work/glove_emb_mat.npy')


  elif re.split('\.', glove_file)[-1] == 'txt':
  if re.split('\.', glove_file)[-1] == 'bin':


In [5]:
# Log-scale bucket for an absolute token distance d:
# 0 -> 0, 1 -> 1, 2-3 -> 2, 4-7 -> 3, ..., 128-255 -> 8, and every d >= 256 -> 9.
dis2idx = np.zeros((max_seq_len,), dtype=np.int64)
dis2idx[1] = 1
for bucket in range(2, 10):
    dis2idx[1 << (bucket - 1):] = bucket

In [6]:
class NpEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy scalars and arrays into plain Python values.

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError`` for unsupported objects.
    """
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)

def _majority_type(mentions):
    """Most frequent mention 'type' of one entity; the first-seen type wins ties."""
    type_count = Counter([m['type'] for m in mentions])
    top = max(type_count.values())
    return next(w for w in type_count if type_count[w] == top)


def _mention_span_weights(mentions, sent_start_pos):
    """Document-level token positions of all mentions of one entity, with pooling weights.

    Each token of a mention gets weight ``1 / span_len / n_mentions``, so the
    weights of one entity sum to 1: an average over mentions of the average
    over each mention's tokens.
    """
    positions, weights = [], []
    n_mentions = len(mentions)
    for ment in mentions:
        start, end = np.array(ment['pos']) + sent_start_pos[ment['sent_id']]
        span = list(range(start, end))
        positions.extend(span)
        # generator: an empty span contributes nothing (no 1/0 evaluation)
        weights.extend(1 / len(span) / n_mentions for _ in span)
    return positions, weights


def dump_data(raw_datas,save_path,y_path,is_train=True,is_validation=False):
    """Convert raw documents into pickled model inputs and per-pair labels.

    Documents are dicts with 'title', 'sents' (list of token lists),
    'vertexSet' (entities -> mentions with 'pos', 'sent_id', 'type') and
    'labels' (dicts with 'h', 't', 'r'), i.e. a DocRED-like layout.

    For every kept document, one ``example`` dict is pickled to ``save_path``
    and the matching list of label indices is pickled to ``y_path``, back to
    back. The loaders read both streams sequentially, so the two files stay
    aligned by position.

    Args:
        raw_datas: list of documents.
        save_path: output file for the example stream (overwritten).
        y_path: output file for the label stream (overwritten).
        is_train: keep labelled pairs plus up to ``sent_rel_max`` randomly
            sampled candidate pairs. (These names were previously read from
            undefined globals, which raised NameError.)
        is_validation: keep every ordered entity pair (overrides is_train).
    """
    # Only documents shorter than token_max_len fit the fixed-width id arrays.
    data = [doc for doc in raw_datas
            if sum(len(st) for st in doc['sents']) < token_max_len]

    with open(y_path, 'wb') as f_y, open(save_path, 'wb') as f:
        for index, temp in tqdm.tqdm(enumerate(data), total=len(data)):
            title = temp['title']
            vertex_set = temp['vertexSet']
            sent_start_pos = np.cumsum([0] + [len(st) for st in temp['sents']])
            sent_org_token = [t for st in temp['sents'] for t in st]

            # Per-token entity id (idx + 1) and entity type id (+1); 0 = no entity.
            # NOTE(review): entity ids must stay below coref_maxlen for the model.
            sent_coref_id = np.zeros((token_max_len,))
            sent_ner_id = np.zeros((token_max_len,))
            for idx, en in enumerate(vertex_set):
                for ment in en:
                    em_pos = np.array(ment['pos']) + sent_start_pos[ment['sent_id']]
                    sent_coref_id[em_pos[0]:em_pos[1]] = idx + 1
                    sent_ner_id[em_pos[0]:em_pos[1]] = type_to_ind[ment['type']] + 1
            sent_coref_id = sent_coref_id.astype(int)
            sent_ner_id = sent_ner_id.astype(int)

            # Word ids: unknown tokens -> 1, right padding -> 0.
            sent_glove_id = [word_to_ind.get(token, 1) for token in sent_org_token]
            sent_glove_id += [0] * (token_max_len - len(sent_glove_id))
            sent_glove_id = np.array(sent_glove_id)

            h_t_dict = {}
            for rel in temp['labels']:
                h_t_dict.setdefault((rel['h'], rel['t']), []).append(rel['r'])

            # Candidate pairs: all ordered pairs, randomly subsampled to
            # sent_rel_max when there are more.
            entity_n = len(vertex_set)
            neg_set = [(i, j) for i in range(entity_n) for j in range(entity_n) if i != j]
            if len(neg_set) > sent_rel_max:
                np.random.shuffle(neg_set)
            neg_set = set(neg_set[:sent_rel_max])

            rels = []
            for h_i in range(entity_n):
                for t_i in range(entity_n):
                    if h_i == t_i:
                        continue
                    pair = (h_i, t_i)
                    keep = is_validation or (is_train and (pair in h_t_dict or pair in neg_set))
                    if not keep:
                        continue

                    h_type_index = type_to_ind[_majority_type(vertex_set[h_i])]
                    t_type_index = type_to_ind[_majority_type(vertex_set[t_i])]

                    # Signed log-bucketed token distance between the first
                    # mentions of head and tail.
                    h_first, t_first = vertex_set[h_i][0], vertex_set[t_i][0]
                    delta_dis = (h_first['pos'][0] + sent_start_pos[h_first['sent_id']]
                                 - t_first['pos'][0] - sent_start_pos[t_first['sent_id']])
                    sent_dist = dis2idx[delta_dis] if delta_dis > 0 else -dis2idx[-delta_dis]

                    # Computed once per entity: the previous nested h x t loop
                    # repeated every span, so weights summed to the other
                    # entity's mention count instead of 1.
                    h_pos, h_map_v = _mention_span_weights(vertex_set[h_i], sent_start_pos)
                    t_pos, t_map_v = _mention_span_weights(vertex_set[t_i], sent_start_pos)

                    labels = h_t_dict.get(pair, ['None'])
                    label_indexs = [rel_to_ind[label] for label in labels]

                    rels.append({'h_pos': h_pos,
                                 'h_map_v': h_map_v,
                                 'h_type_index': h_type_index,
                                 't_pos': t_pos,
                                 't_map_v': t_map_v,
                                 't_type_index': t_type_index,
                                 'sent_dist': sent_dist,
                                 # single-label training: only the first label is kept
                                 'label_indexs': label_indexs[0]})

            # Skip documents without any pair whose label index is non-zero
            # (index 0 is treated as the negative class here and in the metrics).
            if rels and any(r['label_indexs'] > 0 for r in rels):
                example = {'title': title,
                           'index': index,
                           'sent_glove_id': sent_glove_id,
                           'sent_coref_id': sent_coref_id,
                           'sent_ner_id': sent_ner_id,
                           'rels': rels}
                pickle.dump(example, f)
                pickle.dump([r['label_indexs'] for r in rels], f_y)

In [7]:
# with open(r'data/data92314/train_annotated.json') as f:
#     raw = json.load(f)
# dump_data(raw,'./work/train_annotated_data.pkl','./work/train_annotated_y.pkl')

# with open('./train_distant.json') as f:
#     raw = json.load(f)
# dump_data(raw,'./train_distant_data','./train_distant_y.pkl')
# # # dump_data(raw[:int(len(raw)/2)],'/home/aistudio/work/train_distant_data1.pkl','/home/aistudio/work/train_distant_y1.pkl')
# # # dump_data(raw[int(len(raw)/2):],'/home/aistudio/work/train_distant_data2.pkl','/home/aistudio/work/train_distant_y2.pkl')

# with open('data/data92314/dev.json') as f:
#     raw = json.load(f)
# dump_data(raw,'./work/dev_data.pkl','./work/dev_y.pkl')


In [8]:
def _load_pickle_stream(path):
    """Return every object pickled back-to-back into `path` (as written by dump_data).

    NOTE: pickle can execute arbitrary code; only load files produced locally.
    """
    items = []
    with open(path, 'rb') as fh:
        while True:
            try:
                items.append(pickle.load(fh))
            except EOFError:
                return items


train_annotated_y = _load_pickle_stream('./work/train_annotated_y.pkl')
# train_distant_y = _load_pickle_stream('./train_distant_y.pkl')  # distant-supervision labels (disabled)
dev_y = _load_pickle_stream('./work/dev_y.pkl')

# *_y_start[k] is the offset of document k's first pair in the flattened label array;
# the last entry is the total number of pairs.
train_annotated_y_start = np.cumsum([0] + [len(d) for d in train_annotated_y])
train_annotated_y = np.array([y for d in train_annotated_y for y in d])

dev_y_start = np.cumsum([0] + [len(d) for d in dev_y])
dev_y = np.array([y for d in dev_y for y in d])





In [9]:
def get_random_label(train_Y, label_account = None, num_classes = None):
    """Draw a random label for every entry of `train_Y`.

    Args:
        train_Y: integer label array.
        label_account: optional mapping class -> count. When given, labels are
            sampled i.i.d. from the distribution proportional to the counts;
            only ``train_Y.shape[0]`` is used.
        num_classes: number of classes for the uniform-flip mode; defaults to
            the global ``n_class``.

    Returns:
        Array of random labels. Without `label_account`, every label is moved
        to a uniformly chosen *different* class (symmetric label noise).
    """
    if label_account is None:
        if num_classes is None:
            num_classes = n_class
        train_Y = np.asarray(train_Y)
        # Offsets in [1, num_classes - 1]. Offset 0 would leave the label
        # unchanged; the old upper bound (n_class - 1, exclusive) also made one
        # class unreachable.
        noise_int = np.random.randint(low=1, high=num_classes, size=train_Y.shape)
        return np.mod(noise_int + train_Y, num_classes)
    classes = list(label_account.keys())
    counts = np.array([label_account[c] for c in classes], dtype=float)
    return np.random.choice(classes, train_Y.shape[0], p=counts / counts.sum())

In [10]:
def load_data_new(f, ):
    """Wrap an open binary pickle stream in a batching generator factory.

    The returned zero-argument callable yields, for every ``batch_size``
    consecutive pickled examples (the last batch may be smaller), the tuple
    ``(glove_ids, coref_ids, ner_ids, rels_batch)``. The first three are the
    documents' id rows stacked along axis 0 as int32; ``rels_batch`` is the
    list of each document's ``'rels'``. The generator reads `f` until EOF and
    does not close it, so it can be consumed only once per opened file.
    """
    def _stack(rows):
        return np.concatenate(rows, axis=0).astype(np.int32)

    def data_generator():
        glove_rows, coref_rows, ner_rows, rels_batch = [], [], [], []
        n_read = 0
        while True:
            try:
                record = pickle.load(f)
            except EOFError:
                if glove_rows:
                    yield _stack(glove_rows), _stack(coref_rows), _stack(ner_rows), rels_batch
                return
            # A full batch has accumulated: emit it before starting the next one.
            if n_read > 0 and n_read % batch_size == 0:
                yield _stack(glove_rows), _stack(coref_rows), _stack(ner_rows), rels_batch
                glove_rows, coref_rows, ner_rows, rels_batch = [], [], [], []
            glove_rows.append(np.array(record['sent_glove_id']).reshape(1, -1))
            coref_rows.append(np.array(record['sent_coref_id']).reshape(1, -1))
            ner_rows.append(np.array(record['sent_ner_id']).reshape(1, -1))
            rels_batch.append(record['rels'])
            n_read += 1

    return data_generator

In [11]:
class RE_model(paddle.nn.Layer):
    """Relation classifier over (head entity, tail entity) pairs of a batch of documents.

    Token encoder: GloVe word embedding (initialised from ``glove_emb_mat``,
    padding id 0) concatenated with a coreference-id embedding and an
    entity-type embedding, then a bidirectional LSTM with ``dim[2]`` units per
    direction.

    Pair scorer: separate Linear+ReLU projections of the token states for heads
    and tails, weighted pooling of those states into one vector per entity
    (see ``forward``), an additive per-entity-type bias (zero-initialised
    embedding), a concatenated distance-bucket embedding, then
    Bilinear -> SELU -> Linear to ``len(rel_to_ind)`` logits.
    """
    def __init__(self):
        super(RE_model, self).__init__()
        # NOTE: layer creation order fixes parameter naming/initialisation; keep it.
        self.word_embedding = paddle.nn.Embedding(glove_emb_mat.shape[0],glove_emb_mat.shape[1],padding_idx=0)
        self.word_embedding.weight.set_value(glove_emb_mat.astype(np.float32))
        # Ids are entity index + 1 (0 = no entity); must stay below coref_maxlen.
        self.coref_emb_mat = paddle.nn.Embedding(coref_maxlen, coref_emb,padding_idx=0)
        # Ids are type_to_ind + 1 (0 = no entity), hence the extra row.
        self.ner_emb_mat = paddle.nn.Embedding(len(type_to_ind) + 1,ner_emb,padding_idx=0)
        self.seq_lstm = paddle.nn.LSTM(glove_emb_mat.shape[1]+coref_emb+ner_emb,dim[2],direction='bidirectional')
        self.dropout = paddle.nn.Dropout(p=drop_out)
        self.head_linear = paddle.nn.Linear(2 * dim[2], dim[2])
        self.tail_linear = paddle.nn.Linear(2 * dim[2], dim[2])
        # Per-type additive biases for head/tail vectors, starting at zero.
        self.head_type_emb = paddle.nn.Embedding(len(type_to_ind), dim[2])
        self.head_type_emb.weight.set_value(np.zeros((len(type_to_ind),dim[2])).astype(np.float32))
        self.tail_type_emb = paddle.nn.Embedding(len(type_to_ind), dim[2])
        self.tail_type_emb.weight.set_value(np.zeros((len(type_to_ind),dim[2])).astype(np.float32))

        # Distance-bucket embedding; transform_mydata shifts signed buckets by (max_sents-1)/2.
        self.sent_dist_emb = paddle.nn.Embedding(max_sents,sent_dist_emb)

        self.bilinear = paddle.nn.Bilinear(dim[2]+sent_dist_emb, dim[2]+sent_dist_emb, dim[3])

        self.final_linear = paddle.nn.Linear(dim[3],len(rel_to_ind))
        self.selu = paddle.nn.SELU()
        self.relu = paddle.nn.ReLU()

    def forward(self,
                word_id,
                coref_id,
                ner_id,
                h_pos,h_v,
                t_pos,t_v,
                rel_n,
                h_type_index,
                t_type_index,
                ht_dist_index,
                th_dist_index,):
        """Score every entity pair of the batch.

        Args (as produced by ``transform_mydata``):
            word_id, coref_id, ner_id: int id rows, one per document. The row
                length is assumed to equal ``max_seq_len``, because the LSTM
                output is flattened to ``batch * max_seq_len`` token rows.
            h_pos, t_pos: (nnz, 2) int pairs ``[pair row, flattened token index]``.
            h_v, t_v: pooling weight for each entry of h_pos / t_pos.
            rel_n: shape-(1,) total number of pairs in the batch.
            h_type_index, t_type_index: per-pair entity-type ids.
            ht_dist_index, th_dist_index: per-pair distance-bucket ids
                (head->tail and tail->head).

        Returns:
            (logits, features): logits of shape (rel_n, len(rel_to_ind)) and the
            SELU-activated bilinear features of shape (rel_n, dim[3]).
        """
        temp_batch_size = word_id.shape[0]
        
        seq_emb = self.seq_lstm(paddle.concat([self.word_embedding(word_id),
                                               self.coref_emb_mat(coref_id), 
                                               self.ner_emb_mat(ner_id)],axis=-1))
        # seq_lstm returns (outputs, final_states); keep outputs, flatten to token rows.
        seq_emb = paddle.reshape(self.dropout(seq_emb[0]), (-1, 2 * dim[2]))
        
        head_seq_emb = self.relu(self.head_linear(seq_emb))
        tail_seq_emb = self.relu(self.tail_linear(seq_emb))
        
        # Sparse (rel_n x batch*max_seq_len) pooling matrices; paddle.scatter_nd
        # sums updates that share an index.
        h_mask = paddle.scatter_nd(h_pos,paddle.to_tensor(h_v,dtype='float32'),shape=(rel_n,temp_batch_size * max_seq_len))
        t_mask = paddle.scatter_nd(t_pos,paddle.to_tensor(t_v,dtype='float32'),shape=(rel_n,temp_batch_size * max_seq_len))
        
        # One pooled vector per pair: (rel_n, dim[2]).
        h_temp = paddle.matmul(h_mask, head_seq_emb)
        t_temp = paddle.matmul(t_mask, tail_seq_emb)

        temp_head_bias = self.head_type_emb(h_type_index)
        temp_tail_bias = self.tail_type_emb(t_type_index)

        ht_dist_temp = self.sent_dist_emb(ht_dist_index)
        th_dist_temp = self.sent_dist_emb(th_dist_index)

        h_temp = paddle.concat([h_temp + temp_head_bias, ht_dist_temp],axis=1)
        t_temp = paddle.concat([t_temp + temp_tail_bias, th_dist_temp],axis=1)
        
        sub_temp = self.selu(self.bilinear(h_temp,t_temp))
        m_temp = self.final_linear(sub_temp)
        
        return m_temp, sub_temp

class Orthogonal_loss17(paddle.nn.Layer):
    """Feature-decorrelation penalty for a (n_samples, n_features) tensor.

    For every feature column, the squared covariance with its most strongly
    co-varying *other* column is summed. The trace of the covariance (sum of
    variances) is subtracted. The result is normalised by the number of
    non-constant columns and by the number of samples.
    """
    def __init__(self,):
        super(Orthogonal_loss17, self).__init__()
        
    def forward(self, x, ):
        n = x.shape[0]
        m = x.shape[1]

        I = paddle.eye(m)
        e = x - x.mean(axis=0, keepdim = True)
        # Count non-constant columns. The centered column *sum* is ~0 by
        # construction, so the old `e.sum(axis=0) != 0` only counted float
        # noise and could be 0. Clip to >= 1 so an all-constant input yields
        # 0, not NaN.
        m_nonz = (paddle.abs(e).sum(axis=0) != 0).astype('float32').sum()
        m_nonz = paddle.clip(m_nonz, min=1.0)
        
        cov = e.t() @ e
        
        cov2 = cov ** 2
        
        # Column index of the largest off-diagonal squared covariance per row.
        select_i = paddle.argmax(cov2 - cov2 * I, axis = 1)
        cov_m = (paddle.nn.functional.one_hot(select_i, m) * cov2).sum()
        cov_i = (I * cov).sum()
        
        result = (cov_m-cov_i) / m_nonz / n
        return result
# Input signatures (batch / pair counts left dynamic as -1). Nothing in the
# cells shown passes these to paddle.jit or paddle.Model; they document the
# forward() inputs. `np.float` was removed in NumPy 1.24, so h_v now uses
# float32 like t_v (forward() casts both to float32 anyway).
input_word = InputSpec((-1,max_seq_len), np.int32, 'word')
input_coref = InputSpec((-1,max_seq_len), np.int32, 'coref')
input_ner = InputSpec((-1,max_seq_len), np.int32, 'ner')
input_h_pos = InputSpec((-1,2), np.int32, 'h_pos')
input_h_v = InputSpec((-1,), np.float32, 'h_v')
input_t_pos = InputSpec((-1,2), np.int32, 't_pos')
input_t_v = InputSpec((-1,), np.float32, 't_v')
input_rel_n = InputSpec((1,), np.int32, 'rel_n')
input_h_type = InputSpec((-1,), np.int32, 'h_type')
input_t_type = InputSpec((-1,), np.int32, 't_type')
input_ht = InputSpec((-1,), np.int32, 'ht')
input_th = InputSpec((-1,), np.int32, 'th')
input_label = InputSpec((-1,), np.int32, 'label')

re_network1 = RE_model()
cold_loss_fn = paddle.nn.CrossEntropyLoss()

# Learning rate x0.1 at steps 30 and 50; train_predict_model steps it once per epoch.
sch = paddle.optimizer.lr.MultiStepDecay(learning_rate=learning_rate, milestones=[30, 50], gamma=0.1)
opt = paddle.optimizer.Adam(learning_rate=sch,
                            parameters=re_network1.parameters(),
                            epsilon=1e-06, 
                            grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0))





W0425 22:11:38.942303  2114 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0425 22:11:38.946472  2114 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.


In [12]:
def relevant_hard_np(x,):
    """Mean, over feature columns of `x`, of the largest squared correlation with another column.

    `x` is (n_samples, n_features). NaN correlations (e.g. from constant
    columns) count as 0.
    """
    n_feat = x.shape[1]
    corr_sq = np.corrcoef(x.T) ** 2
    corr_sq[np.isnan(corr_sq)] = 0.0
    off_diag = corr_sq - corr_sq * np.eye(n_feat)
    return np.mean(off_diag.max(axis=-1))
def transform_mydata(glove_id,coref_id,ner_id,rels_list):
    """Pack one loaded batch into the tensor list expected by ``RE_model.forward``.

    Pairs are numbered consecutively across the batch (document order, then
    pair order). A token position p of document d maps to the flattened index
    ``d * max_seq_len + p``. Distance buckets are shifted by
    ``(max_sents - 1) / 2`` so that they become non-negative embedding ids.

    Returns:
        [word ids, coref ids, ner ids, h_pos, h_v, t_pos, t_v, rel_n,
         h_type, t_type, ht, th] as paddle tensors.
    """
    centre = (max_sents - 1) / 2
    h_rows, h_cols, h_vals = [], [], []
    t_rows, t_cols, t_vals = [], [], []
    h_types, t_types, ht_list, th_list = [], [], [], []
    pair_row = 0
    for doc_idx, rels in enumerate(rels_list):
        doc_offset = doc_idx * max_seq_len
        for rel in rels:
            for p in rel['h_pos']:
                h_rows.append(pair_row)
                h_cols.append(doc_offset + p)
            h_vals.extend(rel['h_map_v'])
            for p in rel['t_pos']:
                t_rows.append(pair_row)
                t_cols.append(doc_offset + p)
            t_vals.extend(rel['t_map_v'])
            h_types.append(rel['h_type_index'])
            t_types.append(rel['t_type_index'])
            ht_list.append(rel['sent_dist'] + centre)
            th_list.append(-rel['sent_dist'] + centre)
            pair_row += 1

    rel_n = np.array([pair_row], dtype=np.int32)
    h_pos = np.array([h_rows, h_cols], dtype=np.int32).T
    t_pos = np.array([t_rows, t_cols], dtype=np.int32).T
    tensors = [glove_id.astype(np.int32),
               coref_id.astype(np.int32),
               ner_id.astype(np.int32),
               h_pos,
               np.array(h_vals),
               t_pos,
               np.array(t_vals),
               rel_n,
               np.array(h_types, dtype=np.int32),
               np.array(t_types, dtype=np.int32),
               np.array(ht_list, dtype=np.int32),
               np.array(th_list, dtype=np.int32)]
    return [paddle.to_tensor(arr) for arr in tensors]
def validation_new(model):
    """Evaluate `model` on the pickled dev set and print loss / F1 / recall / precision.

    Metrics are micro-averaged over pairs whose label (or prediction) is
    non-zero; label 0 is treated as the negative class.

    Returns:
        (mean loss per batch, F1).

    Raises:
        ValueError: if the dev stream contains no batches.
    """
    model.eval()
    total_y_list = []
    pred_y_list = []
    loss_eval = 0.0
    loss_b_eval = 0.0
    n_batches = 0
    # Close the dev stream when done (it was previously left open).
    with open('./work/dev_data.pkl', 'rb') as dev_f:
        dev_loader = load_data_new(dev_f,)
        for batch_id, data in enumerate(dev_loader()):
            lo = dev_y_start[batch_id * batch_size]
            hi = dev_y_start[min((batch_id + 1) * batch_size, len(dev_y_start) - 1)]
            y_cpu = dev_y[lo:hi].astype(np.int64)
            with paddle.no_grad():
                mydata = transform_mydata(*data)
                logits, feature = model(*mydata)
                loss_b = relevant_hard_np(feature.numpy())
                # Hard int64 labels are what CrossEntropyLoss expects; the old
                # bare-except fallback to one-hot labels hid real errors.
                _loss = cold_loss_fn(logits, paddle.to_tensor(y_cpu))
                pred_y = np.argmax(logits.numpy(), axis=-1)

            total_y_list.append(y_cpu)
            pred_y_list.append(pred_y)
            loss_eval += float(_loss.numpy())
            loss_b_eval += loss_b
            n_batches += 1

    if n_batches == 0:
        raise ValueError('dev stream ./work/dev_data.pkl contained no batches')

    total_y = np.concatenate(total_y_list, axis=0)
    pred_y_total = np.concatenate(pred_y_list, axis=0)

    correct_n = np.sum(np.logical_and(total_y == pred_y_total, total_y != 0))
    pred_n = np.sum(pred_y_total != 0)
    total_n = np.sum(total_y != 0)

    # precision = correct / predicted positives, recall = correct / gold
    # positives (the two were swapped before). Empty denominators give 0.
    precision = correct_n / pred_n if pred_n else 0.0
    recall = correct_n / total_n if total_n else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    loss_eval /= n_batches
    loss_b_eval /= n_batches

    print('loss:%.4f, f1:%.4f, recall:%.4f, precision:%.4f, loss_b:%.4f'%(loss_eval,f1,recall,precision,loss_b_eval))

    return loss_eval,f1

In [13]:
def train_predict_model(loader,Yt_list,y_start,model,acc_list,):
    """Train `model` for one epoch, then update the noisy-label history.

    Args:
        loader: zero-argument callable returning a batch generator
            (``load_data_new``). Batch b covers the flattened pairs
            ``y_start[b*batch_size] : y_start[(b+1)*batch_size]``.
        Yt_list: label history. ``Yt_list[0]`` holds the original labels and
            ``Yt_list[-1]`` the labels trained on this epoch. Exactly one new
            label array is appended per call.
        y_start: cumulative pair offsets per document (n_docs + 1 entries).
        model: the network to train.
        acc_list: unused; kept for call compatibility.

    Returns:
        (loss, f1, recall, precision) of the last logged batch (NaN when no
        batch was logged).

    Relies on notebook globals: epoch_id, time_start, Py_list, opt, sch,
    cold_loss_fn, transform_mydata, batch_size, n_class, dim, t_w, n_C, n_V.
    """
    model.train()
    n_samples = y_start[-1]
    Py_temp = np.zeros((n_samples,), dtype=np.float32)     # prob. of the current label
    Pred_temp = np.zeros((n_samples,), dtype=np.float32)   # argmax prediction
    Probs_temp = np.zeros((n_samples, n_class), dtype=np.float32)
    feature_temp = np.zeros((n_samples, dim[3]), dtype=np.float32)

    def _span(b):
        # [lo, hi) range of flattened pairs covered by batch b
        return y_start[b * batch_size], y_start[min((b + 1) * batch_size, len(y_start) - 1)]

    # Label probabilities from the previous epoch (zeros before the first one).
    Py_mean = Py_list[-1] if len(Py_list) > 0 else np.zeros((n_samples,))

    # OOD_mask: per original class, the ~1% of pairs with the lowest Py.
    # (`np.bool` no longer exists in NumPy >= 1.24.)
    OOD_mask = np.zeros((n_samples,), dtype=bool)
    for j_ in range(n_class):
        class_mask = Yt_list[0] == j_
        class_n = class_mask.sum()
        if class_n == 0:
            continue
        class_thres = np.sort(Py_mean[class_mask])[int((class_n - 1) * 0.01)]
        OOD_mask[np.logical_and(class_mask, Py_mean <= class_thres)] = True

    is_neg = (Yt_list[-1] == 0).ravel()
    log_every = max(1, int(3000 / batch_size))
    _loss = f1 = recall = precision = float('nan')
    for batch_id, data in enumerate(loader()):
        lo, hi = _span(batch_id)
        mydata = transform_mydata(*data)
        y_cpu = Yt_list[-1][lo:hi].astype(np.int64)
        Y_GPU = paddle.to_tensor(y_cpu)
        y_onehot = paddle.nn.functional.one_hot(paddle.reshape(Y_GPU, (-1,)), num_classes=n_class)
        logits, feature = model.forward(*mydata)

        probs = paddle.nn.functional.softmax(logits)
        Py = paddle.sum(y_onehot * probs, axis=-1)
        Pred = paddle.argmax(probs, axis=-1)
        # Best class other than the current label. Mask the label logit with a
        # large negative value; zeroing it let the label itself win whenever
        # all other logits were negative.
        logits_other = paddle.where(y_onehot > 0, paddle.full_like(logits, -1e9), logits)
        Pred_other = paddle.argmax(logits_other, axis=-1)

        Py_temp[lo:hi] = Py.numpy()
        Pred_temp[lo:hi] = Pred.numpy()
        Probs_temp[lo:hi] = probs.numpy()
        feature_temp[lo:hi] = feature.numpy()

        if epoch_id >= t_w:
            # After warm-up, push the lowest-confidence pairs toward their best alternative class.
            Y_GPU = paddle.where(paddle.to_tensor(OOD_mask[lo:hi]), Pred_other, Y_GPU)
        loss = cold_loss_fn(logits, Y_GPU)

        loss.backward()
        opt.step()
        opt.clear_grad()

        if batch_id % log_every == 0:
            pred_y_temp = np.argmax(logits.numpy(), axis=-1)
            correct_n = np.sum(np.logical_and(y_cpu == pred_y_temp, y_cpu != 0))
            pred_n = np.sum(pred_y_temp != 0)
            total_n = np.sum(y_cpu != 0)
            # precision / recall were swapped before; empty denominators give 0.
            precision = correct_n / pred_n if pred_n else 0.0
            recall = correct_n / total_n if total_n else 0.0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
            _loss = float(loss.numpy())
            time_end = time.time()
            print('batch_id:%4d, train loss :%.4f, f1:%.4f, recall:%.4f, precision:%.4f, time elapsed %4d'
            %(batch_id, _loss, f1, recall, precision, time_end - time_start), flush=True)
    print('epoch %d train complete'%epoch_id)

    sch.step()

    Py_list.append(Py_temp)

    if epoch_id >= t_w:
        # Mean label probability over all epochs recorded so far.
        Py_mean = np.zeros((n_samples,))
        for past_py in Py_list:
            Py_mean += past_py
        Py_mean /= len(Py_list)

        # Nearly clean set: top n_C fraction (at least one pair) of each class by mean Py.
        nC_points = []
        for j in range(n_class):
            class_mask = Yt_list[0] == j
            c_n = class_mask.sum()
            if c_n == 0:
                continue
            # max(1, ...): int(c_n * n_C) == 0 used to index [-0] and select the whole class.
            c_th = np.sort(Py_mean[class_mask])[-max(1, int(c_n * n_C))]
            nC_points.append(np.where(np.logical_and(Py_mean >= c_th, class_mask))[0])
        nC_points = np.concatenate(nC_points)

        # Per-pair score from feature inner products with the nearly-clean set
        # (presumably an influence-style estimate). The highest scores per class
        # are relabeled below.
        L_batch = 1000
        Y_onehot = np.eye(n_class, dtype=np.float32)[Yt_list[-1]]  # float32, like the probs
        zC = paddle.to_tensor(feature_temp[nC_points])
        fC = paddle.to_tensor(Probs_temp[nC_points])
        yC = paddle.to_tensor(Y_onehot[nC_points])
        fCcyC = fC - yC
        lr = 1e-6
        learning_risk = np.zeros((n_samples,))
        for start in range(0, n_samples, L_batch):
            i_ind = np.arange(start, min(n_samples, start + L_batch))
            zi = paddle.to_tensor(feature_temp[i_ind])
            fi = paddle.to_tensor(Probs_temp[i_ind])
            yi = paddle.to_tensor(Y_onehot[i_ind])
            part_1_1 = (zi @ zC.transpose([1, 0]) + 1) @ fCcyC
            part1 = (part_1_1 * (yi - fi)).sum(axis=-1, keepdim=True) * 4 * lr / len(nC_points)
            learning_risk[i_ind] = part1.numpy().ravel()

        # Relabel the top n_V fraction (by learning_risk) of each current class to the prediction.
        select = np.zeros((n_samples,), dtype=bool)
        for j in range(n_class):
            class_mask = Yt_list[-1] == j
            c_n = class_mask.sum()
            if c_n > 2:
                c_th = np.sort(learning_risk[class_mask])[-min(int(np.ceil(c_n * n_V)), c_n)]
                select[np.logical_and(learning_risk >= c_th, class_mask)] = True

        temp_Yt_noised = np.where(select.ravel(), Pred_temp.ravel(), Yt_list[-1].ravel()).astype(int)
    else:
        # Warm-up: keep the original labels (appended once; previously twice).
        temp_Yt_noised = Yt_list[0]
        select = np.zeros((n_samples,), dtype=bool)
    Yt_list.append(temp_Yt_noised)
    print('epoch %d train cleaned, %d samples selected, %d neg samples selected'%
            (epoch_id,np.sum(select),np.sum(np.logical_and(select,is_neg))))
    print('total changed %d'%(np.sum(Yt_list[-1].ravel() != Yt_list[0].ravel())))
    return _loss, f1, recall, precision


In [14]:
# Per-run histories. Re-running this cell clears them (the model and optimizer are not reset).
Py_list = []                     # per-epoch arrays of the probability of each pair's label
Yt_list = [train_annotated_y]    # label history: [0] original labels, [-1] current labels

acc_list = []
f1_val_list, loss_val_list, loss_list = [], [], []

loss_train_list, f1_train_list = [], []
recall_train_list, precision_train_list = [], []

In [None]:
# Main loop: one training epoch (with label correction after warm-up) followed
# by a dev evaluation. Training data is read in its pickled order every epoch,
# because labels are aligned to pairs by position.
time_start = time.time()

for epoch_id in range(epoch_n):
    # The batch generator reads the stream to EOF; reopen it each epoch and
    # close it afterwards (one handle used to leak per epoch).
    with open('./work/train_annotated_data.pkl', 'rb') as train_dis_f:
        train_dis_loader = load_data_new(train_dis_f,)
        _loss, f1, recall, precision = train_predict_model(
            train_dis_loader, Yt_list, train_annotated_y_start, re_network1, acc_list,)
    loss_train_list.append(_loss)
    f1_train_list.append(f1)
    recall_train_list.append(recall)
    precision_train_list.append(precision)
    loss_val, f1_val = validation_new(re_network1)
    f1_val_list.append(f1_val)
    loss_val_list.append(loss_val)
    gc.collect()