In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from transformers import *
from tqdm import tqdm
import random
import os
import time

In [2]:
# Load both competition splits from CSV files in the working directory.
train, test = (pd.read_csv(csv_name) for csv_name in ('train.csv', 'test.csv'))

In [3]:
# Peek at the first rows to confirm the expected columns were parsed.
train.head(n=5)

Unnamed: 0,question_id,question_title,question_detail,tag_ids
0,0,为什么有的孩子就比同龄的孩子机灵，知道该讨好谁，知道谁比较好说话啊？还知道怎样做不会惹老师生气？,小学的孩子，知道看眉眼高低，懂得老师生气的时候就尽量躲着点？问过家长在家有没有特别教过，家长...,967|8922|240|396
1,1,怎么看待男朋友说玩游戏顺手带妹？,跟男朋友谈了一年左右，刚开始知道他喜欢玩游戏，他说他不带妹，后来暑假刚开始，他疯狂的泡在游戏...,69|109
2,2,武林人士退隐江湖之后会过着怎样的生活？,欢迎各种脑洞~,35|211|230|61|1157
3,3,“一见倾心，再见依然。”这个句子的唯美英文翻译？求大神解答！,,475|15392|2163
4,4,如何看待加拿大国际数学奥林匹克竞赛团队都是华人？,,6803|446|3216|4079|930


In [4]:
# Length of the multi-hot label vector; tag ids are assumed to lie in 1..num_classes
# (MyDataset subtracts 1 to index the vector) — TODO confirm against the tag list.
num_classes = 25551
# Checkpoint name passed to from_pretrained for both the tokenizer and the encoder.
model_path = 'hfl/chinese-bert-wwm-ext'

# Seed every RNG the notebook touches (Python `random`, NumPy, torch) so that
# DataLoader shuffling and the classifier-head initialisation repeat across runs.
# GPU kernels may still introduce some nondeterminism.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Tokenizer for the same checkpoint as the encoder; MyDataset reads it as a global.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_path)

In [6]:
class MyDataset(Dataset):
    """Question-tagging dataset that tokenizes ``title + detail`` on the fly.

    Each item is ``(input_ids, attention_mask)`` when ``test`` is True, otherwise
    ``(input_ids, attention_mask, label)`` where ``label`` is a float multi-hot
    vector of length ``num_classes`` (module-level constant).

    Args:
        dataframe: frame with ``question_title`` and ``question_detail`` columns,
            plus ``tag_ids`` ('|'-separated integer ids) unless ``test`` is True.
        maxlen: every sequence is padded / truncated to exactly this many tokens.
        test: if True, labels are neither built nor returned.
    """

    def __init__(self, dataframe, maxlen=256, test=False):
        self.df = dataframe
        self.maxlen = maxlen
        self.test = test

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _clean_text(value):
        """Return ``value`` as a string, mapping missing values (NaN/None) to ''."""
        # pandas reads empty CSV cells as NaN; a plain str() would feed the
        # literal text "nan" to the tokenizer.
        return '' if pd.isna(value) else str(value)

    def __getitem__(self, idx):
        text = (self._clean_text(self.df.question_title.values[idx])
                + self._clean_text(self.df.question_detail.values[idx]))

        encoding = tokenizer(text, padding='max_length', truncation=True,
                             max_length=self.maxlen, return_tensors='pt')

        # return_tensors='pt' adds a leading batch dimension of size 1; drop it.
        input_ids = encoding['input_ids'][0]
        attention_mask = encoding['attention_mask'][0]

        if self.test:
            return input_ids, attention_mask

        # Ids are shifted by -1 to index the label vector (the submission cell adds 1 back).
        # str() guards against pandas parsing a tag column as plain integers.
        tags = [int(x) - 1 for x in str(self.df.tag_ids.values[idx]).split('|')]
        label = torch.zeros((num_classes,))
        label[tags] = 1
        return input_ids, attention_mask, label

In [7]:
# Labelled training set and unlabelled test set (test=True skips tag parsing).
train_set, test_set = MyDataset(train), MyDataset(test, test=True)

In [8]:
# Shuffle only the training data: test predictions are written out in loader
# order, so the test loader must preserve the row order of test.csv.
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=16)
test_loader = DataLoader(dataset=test_set, shuffle=False, batch_size=16)

In [9]:
class Model(nn.Module):
    """Pretrained BERT encoder plus a linear head giving independent per-tag scores.

    ``forward(input_ids, attention_mask)`` returns sigmoid probabilities with last
    dimension ``num_classes``. Inputs are presumably (batch, seq_len) token ids
    and mask — TODO confirm.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_path)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        # Index 1 is the pooler output for both tuple and ModelOutput results.
        # The former `[-1]` silently picked a different element whenever hidden
        # states or attentions were also returned.
        pooled = self.bert(input_ids, attention_mask=attention_mask)[1]
        logits = self.fc(pooled)
        # Probabilities (not logits) are returned: training uses nn.BCELoss and
        # inference thresholds the scores at 0.5.
        return torch.sigmoid(logits)

In [10]:
class AverageMeter:
    """Tracks the most recent value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times (e.g. a batch mean over ``n`` samples)."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
        
        
def train_model(model, train_loader, opt=None, loss_fn=None):
    """Run one training epoch and return the batch-size-weighted mean loss.

    Args:
        model: module called as ``model(input_ids, attention_mask)``. It is
            switched to train mode, and batches are moved to the device of its
            parameters.
        train_loader: iterable yielding ``(input_ids, attention_mask, y)`` batches.
        opt: optimizer to step; defaults to the notebook-level ``optimizer``.
        loss_fn: called as ``loss_fn(output, y)``; defaults to the notebook-level
            ``criterion``.

    Returns:
        Mean of the per-batch losses weighted by batch size (0 for an empty loader).
    """
    if opt is None:
        opt = optimizer
    if loss_fn is None:
        loss_fn = criterion

    model.train()
    # Follow the model's device instead of hard-coding .cuda().
    device = next(model.parameters()).device
    losses = AverageMeter()

    tk = tqdm(train_loader, total=len(train_loader), position=0, leave=True)
    for input_ids, attention_mask, y in tk:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)

        opt.zero_grad()
        output = model(input_ids, attention_mask)
        loss = loss_fn(output, y)
        loss.backward()
        opt.step()

        losses.update(loss.item(), y.size(0))
        tk.set_postfix(loss=losses.avg)

    return losses.avg

In [11]:
# Build the BERT tagger and move all of its parameters onto the current GPU.
model = Model().to('cuda')

In [12]:
# Model.forward already applies a sigmoid, so the loss consumes probabilities.
# NOTE(review): BCEWithLogitsLoss on raw logits would be more numerically stable,
# but it requires Model to stop applying sigmoid.
criterion = nn.BCELoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5, weight_decay=1e-4)

In [None]:
# No validation loop is visible in this notebook, so a fixed epoch count is the
# only stopping criterion.
NUM_EPOCHS = 50
train_history = []
for epoch in range(NUM_EPOCHS):
    train_loss = train_model(model, train_loader)
    # Keep the per-epoch loss, which was previously computed and then discarded.
    train_history.append(train_loss)
    print(f'epoch {epoch + 1}/{NUM_EPOCHS} train loss {train_loss:.6f}')
    # Overwritten every epoch: 'model.pt' always holds the latest weights.
    torch.save(model.state_dict(), 'model.pt')

In [None]:
def select_tags(probs, threshold=0.5, max_tags=5):
    """Choose the predicted 0-based class indices for one sample.

    Post-processing rule (the threshold can be tuned): keep the classes whose
    probability exceeds ``threshold``. If more than ``max_tags`` qualify, keep
    the ``max_tags`` most probable; if none qualify, keep only the single most
    probable class.

    Args:
        probs: 1-D tensor of per-class probabilities.
        threshold: probability a class must strictly exceed to be selected.
        max_tags: upper bound on the number of returned indices.

    Returns:
        1-D numpy array of between 1 and ``max_tags`` indices, most probable first.
    """
    n_above = int((probs > threshold).sum())
    # The n classes above the threshold are exactly the top-n by probability,
    # so one topk covers all three cases. The result is always an array (never a
    # bare scalar), which keeps every row iterable for the submission writer.
    k = min(max(n_above, 1), max_tags)
    _, indices = torch.topk(probs, k)
    return indices.cpu().numpy()


result = []

model.eval()
# Feed batches to whatever device the model's parameters live on.
device = next(model.parameters()).device

tk = tqdm(test_loader, total=len(test_loader), position=0, leave=True)

with torch.no_grad():
    for input_ids, attention_mask in tk:
        output = model(input_ids.to(device), attention_mask.to(device))
        result.extend(select_tags(res) for res in output)

In [None]:
# One line per test row: "<row index>,<tag1>,...,<tag5>". Tags are shifted back
# by +1 and padded with -1 when fewer than 5 were predicted.
# NOTE(review): the row index i is written as the question id. This assumes
# test.csv's question_id runs 0..len(test)-1 in file order — confirm.
with open('submission.csv', 'w') as f:
    for i, tags in enumerate(result):
        # np.atleast_1d: a prediction may be a bare numpy scalar (a single
        # top-1 tag), which cannot be iterated directly.
        res = [str(int(x) + 1) for x in np.atleast_1d(tags)][:5]
        res += ['-1'] * (5 - len(res))
        f.write(str(i) + ',' + ','.join(res) + '\n')