In [None]:
import pickle
import time
import torch
from torch.optim import lr_scheduler
from torch.optim.adam import Adam
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
import os
from tqdm import tqdm
import torch.nn as nn
from Dataset import ServeNetDataset
from data_pre import load_data_train, load_data_test
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from sklearn import metrics

In [None]:
# Restrict the CUDA devices visible to this process to GPUs 0-3.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# Path handed to BertModel.from_pretrained when ServeNet builds its two encoders.
UNCASED = './bert-base-uncased'
# Vocabulary file name. NOTE(review): not referenced in the visible cells — confirm it is still needed.
VOCAB = 'vocab.txt'
# Number of passes over the training DataLoader in the training loop.
epochs = 40
# Random seed value. NOTE(review): no seeding call is visible in this notebook — confirm where it is applied.
SEED = 123
# Initial learning rate given to the optimizer.
LEARNING_RATE = 0.01
# NOTE(review): WEIGHT_DECAY and EPSILON are not passed to any optimizer in the visible cells
# (presumably left over from an Adam/AdamW configuration) — confirm before relying on them.
WEIGHT_DECAY = 0.01
EPSILON = 1e-8
# Batch size used for both the train and the test DataLoader.
BATCH_SIZE=128
# Number of target classes: passed to the data loading helpers and to ServeNet,
# and used by evaluteTop1 to size its per-class counters.
CLASS_NUM=250

In [None]:
def evaluteTop1(model, dataLoader):
    """Compute the top-1 accuracy (in percent) of ``model`` over ``dataLoader``.

    Besides returning the overall accuracy, the function prints, for every
    class index in ``range(CLASS_NUM)``, the per-class accuracy followed by
    the per-class sample count.

    Args:
        model: callable taking ``(names, descriptions)``, each a 3-tuple of
            tensors, and returning a ``(batch, num_classes)`` score tensor.
            It is switched to eval mode.
        dataLoader: iterable of batches. Each batch is indexable; items 0-2
            are the description inputs, items 3-5 the name inputs and item 6
            the integer class labels. Every item must provide ``.cuda()``.

    Returns:
        Overall accuracy as a percentage, or ``0.0`` when the loader yields
        no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # Float counters keep the printed values in the same "12.0" format as before.
    class_correct = [0.] * CLASS_NUM
    class_total = [0.] * CLASS_NUM
    with torch.no_grad():
        for data in dataLoader:
            input_tokens_name = data[3].cuda()
            segment_ids_name = data[4].cuda()
            input_masks_name = data[5].cuda()

            input_tokens_descriptions = data[0].cuda()
            segment_ids_descriptions = data[1].cuda()
            input_masks_descriptions = data[2].cuda()
            label = data[6].cuda()

            outputs = model((input_tokens_name, segment_ids_name, input_masks_name),
                            (input_tokens_descriptions, segment_ids_descriptions, input_masks_descriptions))

            _, predicted = torch.max(outputs, 1)
            # Flatten instead of squeeze(): squeeze() collapses a batch of one
            # sample into a 0-dim tensor, which cannot be indexed per sample.
            hits = (predicted == label).reshape(-1)
            total += label.size(0)
            correct += hits.sum().item()

            # Per-class bookkeeping; tolist() moves the batch to the host once
            # instead of issuing one .item() call per sample.
            for class_idx, hit in zip(label.reshape(-1).tolist(), hits.tolist()):
                class_correct[class_idx] += hit
                class_total[class_idx] += 1

    print('each class accuracy of: ' )
    for i in range(CLASS_NUM):
        if class_total[i] > 0:
            print(100 * class_correct[i] / class_total[i])
        else:
            # Class never appeared in the loader: its accuracy is undefined.
            print(float('nan'))

    print('total class_total: ')
    for i in range(CLASS_NUM):
        print(class_total[i])

    if total == 0:
        return 0.0
    return 100 * correct / total


def evaluteTop5(model, dataLoader):
    """Return the top-5 accuracy (in percent) of ``model`` over ``dataLoader``.

    A sample counts as a hit when its label is among the five highest-scoring
    classes produced by the model. The model is put into eval mode and run
    without gradient tracking.

    Each batch is indexable: items 0-2 are the description inputs, items 3-5
    the name inputs and item 6 the labels; every item is moved with ``.cuda()``.
    """
    top_k = 5
    hit_count = 0
    sample_count = 0

    model.eval()
    with torch.no_grad():
        for batch in dataLoader:
            # Move the seven batch fields to the GPU (names first, as the model expects).
            name_tokens, name_segments, name_masks = (batch[idx].cuda() for idx in (3, 4, 5))
            desc_tokens, desc_segments, desc_masks = (batch[idx].cuda() for idx in (0, 1, 2))
            targets = batch[6].cuda()

            scores = model((name_tokens, name_segments, name_masks),
                           (desc_tokens, desc_segments, desc_masks))

            # Indices of the k best classes per sample, compared against the
            # label broadcast as a column vector.
            top_indices = scores.topk(top_k, dim=1, largest=True, sorted=True)[1]
            matches = torch.eq(top_indices, targets.view(-1, 1))

            sample_count += targets.size(0)
            hit_count += matches.sum().float().item()

    return 100 * hit_count / sample_count

class weighted_sum(nn.Module):
    """Learnable weighted sum of two tensors: ``input1 * w1 + input2 * w2``.

    ``w1`` and ``w2`` are scalar (shape ``(1,)``) trainable parameters that
    broadcast over the inputs.
    """

    def __init__(self):
        super(weighted_sum, self).__init__()
        # torch.FloatTensor(1) only allocates memory, so the weights used to
        # start from whatever bytes happened to be there (possibly NaN/inf).
        # Start from a plain sum instead; shape and names are unchanged, so
        # existing state_dicts still load.
        self.w1 = nn.Parameter(torch.ones(1), requires_grad=True)
        self.w2 = nn.Parameter(torch.ones(1), requires_grad=True)

    def forward(self, input1, input2):
        """Return the element-wise combination of two broadcast-compatible tensors."""
        return input1 * self.w1 + input2 * self.w2


In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import numpy as np
import math

class MutliHead(nn.Module):
    """Multi-head normalized linear classifier.

    Both the input features and the class weight matrix are split into
    ``num_head`` chunks of ``feat_dim // num_head`` columns. Each input chunk
    is L2-normalized, each weight chunk is divided by ``norm + gamma``, and
    the logits are ``(normed_x * tau / num_head) @ normed_w.T``.

    Args:
        feat_dim: size of the input feature vectors.
        num_classes: number of output logits.
        use_effect: stored on the module; not read by ``forward``.
        num_head: number of chunks the feature dimension is split into.
        tau: temperature; the per-head scale is ``tau / num_head``.
        alpha: stored on the module; not read by ``forward``.
        gamma: constant added to the weight norm in ``causal_norm``.
    """

    def __init__(self,
                 feat_dim=1024,
                 num_classes=250,
                 use_effect=True,
                 num_head=2, #2, 4
                 tau=16.0, #16, 32
                 alpha=0, # 0, 1, 1.5, 3
                 gamma=0.03125):
        super(MutliHead, self).__init__()

        # Allocated uninitialized here, filled by reset_parameters() below.
        self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim), requires_grad=True)
        self.scale = tau / num_head
        self.norm_scale = gamma
        self.alpha = alpha
        self.num_head = num_head
        self.head_dim = feat_dim // num_head
        self.use_effect = use_effect

        # MU and causal_embed are kept as module state but are not used by
        # forward() in this class.
        self.MU = 1.0 - (1 - 0.9) * 0.02

        self.causal_embed = nn.Parameter(torch.FloatTensor(1, feat_dim).fill_(1e-10), requires_grad=False)

        self.reset_parameters(self.weight)

    def reset_parameters(self, weight):
        """Fill ``weight`` in place with samples from N(0, 0.01)."""
        nn.init.normal_(weight, 0, 0.01)

    def get_cos_sin(self, x, y):
        """Return ``(cos, sin)`` of the angle between matching rows of ``x`` and ``y``.

        Both results keep a trailing dimension of size 1.
        """
        cos_val = (x * y).sum(-1, keepdim=True) / torch.norm(x, 2, 1, keepdim=True) / torch.norm(y, 2, 1, keepdim=True)
        sin_val = (1 - cos_val * cos_val).sqrt()
        return cos_val, sin_val

    def multi_head_call(self, func, x, weight=None):
        """Apply ``func`` to every head-sized column chunk of ``x`` and re-concatenate.

        Args:
            func: callable invoked as ``func(chunk)`` or, when ``weight`` is
                given, ``func(chunk, weight)``.
            x: 2-D tensor whose second dimension is split into chunks of
                ``self.head_dim`` columns.
            weight: optional extra argument forwarded to ``func``.

        Returns:
            The per-chunk results concatenated along dimension 1.
        """
        assert len(x.shape) == 2
        x_list = torch.split(x, self.head_dim, dim=1)
        # Compare against None explicitly: a legitimate weight of 0 is falsy,
        # and used to drop the argument and crash inside ``func``.
        if weight is not None:
            y_list = [func(item, weight) for item in x_list]
        else:
            y_list = [func(item) for item in x_list]
        assert len(x_list) == self.num_head
        assert len(y_list) == self.num_head
        return torch.cat(y_list, dim=1)

    def l2_norm(self, x):
        """Divide each row of ``x`` by its L2 norm (plus 1e-8 for stability)."""
        normed_x = x / (torch.norm(x, 2, 1, keepdim=True) + 1e-8)
        return normed_x

    def causal_norm(self, x, weight):
        """Divide each row of ``x`` by ``(its L2 norm + weight)``."""
        norm= torch.norm(x, 2, 1, keepdim=True)
        normed_x = x / (norm + weight)
        return normed_x

    def init_weights(self):
        """Re-initialize the class weight matrix."""
        self.reset_parameters(self.weight)

    def forward(self, x):
        """Map ``(batch, feat_dim)`` features to ``(batch, num_classes)`` logits."""
        normed_w = self.multi_head_call(self.causal_norm, self.weight, weight=self.norm_scale)
        normed_x = self.multi_head_call(self.l2_norm, x)
        y = torch.mm(normed_x * self.scale, normed_w.t())

        return y


In [None]:
class ServeNet(torch.nn.Module):
    """Two-branch classifier over a service name and a service description.

    * Name branch: BERT -> output index 1 (the pooled output of Hugging Face's
      ``BertModel``) -> Linear(hiddenSize, 1024) -> ReLU -> Dropout.
    * Description branch: BERT -> output index 0 (per-token states) ->
      bidirectional LSTM (2 x 512); the final cell states of both directions
      are concatenated into a 1024-d vector.

    The two 1024-d vectors are merged by ``weighted_sum`` and classified by
    ``MutliHead``.

    Args:
        hiddenSize: feature size produced by the BERT encoders (input size of
            the linear layer and of the LSTM).
        CLASS_NUM: number of output classes.
    """

    def __init__(self, hiddenSize,CLASS_NUM):
        super(ServeNet, self).__init__()
        self.hiddenSize = hiddenSize

        self.bert_name = BertModel.from_pretrained(UNCASED)
        self.bert_description = BertModel.from_pretrained(UNCASED)

        self.name_liner = nn.Linear(in_features=self.hiddenSize, out_features=1024)
        self.name_ReLU = nn.ReLU()
        self.name_Dropout = nn.Dropout(p=0.1)

        self.lstm = nn.LSTM(input_size=self.hiddenSize, hidden_size=512, num_layers=1, batch_first=True,
                            bidirectional=True)

        self.weight_sum = weighted_sum()
        # Honor the requested class count; the head used to be built with its
        # default of 250 classes regardless of CLASS_NUM.
        self.mutliHead = MutliHead(num_classes=CLASS_NUM)

    def forward(self, names, descriptions):
        """Return class logits for a batch.

        Args:
            names: tuple ``(input_tokens, segment_ids, input_masks)`` for the
                service names.
            descriptions: tuple with the same layout for the descriptions.

        Returns:
            Tensor of shape ``(batch, CLASS_NUM)``.
        """
        self.lstm.flatten_parameters()
        input_tokens_names, segment_ids_names, input_masks_names = names
        input_tokens_descriptions, segment_ids_descriptions, input_masks_descriptions = descriptions

        # name
        # Keyword arguments on purpose: transformers' BertModel.forward takes
        # (input_ids, attention_mask, token_type_ids), so passing
        # (tokens, segment_ids, masks) positionally swaps mask and segment ids.
        name_bert_output = self.bert_name(input_ids=input_tokens_names,
                                          token_type_ids=segment_ids_names,
                                          attention_mask=input_masks_names)
        # Feature for Name
        name_features = self.name_liner(name_bert_output[1])
        name_features = self.name_ReLU(name_features)
        name_features = self.name_Dropout(name_features)

        # description
        description_bert_output = self.bert_description(input_ids=input_tokens_descriptions,
                                                        token_type_ids=segment_ids_descriptions,
                                                        attention_mask=input_masks_descriptions)

        description_bert_feature=description_bert_output[0]

        # LSTM
        packed_output, (hidden, cell) = self.lstm(description_bert_feature)
        # The description feature is built from the final *cell* states of the
        # forward and backward directions (not the hidden states).
        hidden = torch.cat((cell[0, :, :], cell[1, :, :]), dim=1)
        #hidden = torch.cat((hidden[0, :, :], hidden[1, :, :]), dim=1)

        # sum
        all_features = self.weight_sum(name_features, hidden)
        output = self.mutliHead(all_features)

        return output


In [None]:
if __name__ == "__main__":
    # Train ServeNet with SGD + exponential LR decay and report top-1 / top-5
    # accuracy on the test split after every epoch.

    train_data = load_data_train(CLASS_NUM)
    test_data = load_data_test(CLASS_NUM)
    # NOTE(review): batches are drawn in dataset order (no shuffle / sampler is
    # configured) — confirm load_data_train already shuffles the samples.
    # train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE)
    test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

    # 768 is the hiddenSize given to ServeNet (feature size of the BERT outputs it consumes).
    model = ServeNet(768,CLASS_NUM)
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    model.train()

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
    # optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

    for epoch in range(epochs):
        # The logged value is the learning rate actually used for this epoch.
        print("Epoch:{},lr:{}".format(str(epoch+1),str(optimizer.param_groups[0]['lr'])))
        # The evaluation helpers switch to eval mode, so re-enable training mode.
        model.train()
        for data in tqdm(train_dataloader):

            input_tokens_name = data[3].cuda()
            segment_ids_name = data[4].cuda()
            input_masks_name = data[5].cuda()

            input_tokens_descriptions = data[0].cuda()
            segment_ids_descriptions = data[1].cuda()
            input_masks_descriptions = data[2].cuda()
            label = data[6].cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model((input_tokens_name, segment_ids_name, input_masks_name),
                            (input_tokens_descriptions, segment_ids_descriptions, input_masks_descriptions))

            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

        # Decay the learning rate once per epoch, *after* the optimizer updates.
        # Stepping the scheduler before any optimizer.step() skips the first
        # value of the schedule (and PyTorch >= 1.1 warns about it).
        scheduler.step()

        print("=======>top1 acc on the test:{}".format(str(evaluteTop1(model, test_dataloader))))
        print("=======>top5 acc on the test:{}".format(str(evaluteTop5(model, test_dataloader))))