In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

#DGL
import dgl

from dataset import DocData
from model import GATON,MyLoss
from evaluate import computeClassificationMetric,computeCoherence
import random

from apex import amp

`fused_weight_gradient_mlp_cuda` module not found. gradient accumulation fusion with weight gradient computation disabled.


In [5]:
seed=2022
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [6]:
#超参
topic_num=20
graph_input_dim=64
graph_hidden_dim=128
device="cuda:0"

head_num=2

epochs=20

data=DocData("/home/v-ruiruiwang/notebooks/code/GATON/data/20NG_mindf_97_vocab_2004_pretrain.pkl")

In [7]:
model = GATON(
    data.graph,
    topic_num,
    data.vocabulary_size,
    data.word_embedding_size,
    graph_input_dim,
    graph_hidden_dim,
    head_num
    ).to(device)

data.graph.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-3,weight_decay=0.005)
loss_fn=MyLoss()

model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”，不是“零一”



model.train()

for epoch in range(epochs):

    doc_input = data.doc_word_frequency.to(device) # 初始的 doc 表示 
    word_input = data.word_embeddings.to(device)    #初始的 word 表示

    doc_topic_prob,word_topic_prob = model(doc_input,word_input)

    #计算loss
    doc_word_occ = torch.matmul(doc_topic_prob,word_topic_prob.permute(1,0))
    recon_loss = loss_fn(doc_word_occ,doc_input)
    total_loss=recon_loss
    optimizer.zero_grad()
    # total_loss.backward()
    
    # amp
    with amp.scale_loss(total_loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    
    optimizer.step()

    print("==========================================================================")
    print("epoch start ====================================================================")
    print('cur reconstruct loss:',recon_loss.item())
    # print('cur classfication loss:',class_loss.item())
    print('cur total  loss:',total_loss.item())
    
    model.eval()
    with torch.no_grad():
        word_topic_dis = word_topic_prob.cpu().numpy().transpose()
        doc_topic_dis = doc_topic_prob.cpu().argmax(dim=1).numpy()
    #验证分类
    computeClassificationMetric(doc_topic_dis,np.array(data.labels),True)
    #验证Topic Coherence
    computeCoherence(word_topic_dis,data.ins_tr,data.tgt_keys,data.idx2word,10,True)

    print("epoch end====================================================================")
    print("==========================================================================")
    

cur reconstruct loss: 82712.7578125
cur total  loss: 82712.7578125
f1 macro: 0.008963552550904544 f1 micro: 0.052646291464778235 precision: 0.052646291464778235 accuracy: 0.052646291464778235 recall 0.052646291464778235
Average Topic Coherence = 0.116
Average  gaton Topic Coherence = -269.989
0.15551234996416038
['atheism', 'atheists', 'religion', 'atheist', 'god']
['covered', 'due', 'demand', 'users', 'sleep', 'fast', 'explain', 'dave', 'today', 'colorado']
0.17886358127807678
['graphics', 'image', 'images', 'gif', 'format']
['side', 'due', 'demand', 'increased', 'users', 'connect', 'dave', 'fast', 'explain', 'today']
0.12602252297565253
['windows', 'dos', 'file', 'ms', 'microsoft']
['os', 'side', 'covered', 'due', 'demand', 'increased', 'users', 'connect', 'sleep', 'explain']
0.16636105927393408
['ide', 'controller', 'card', 'bus', 'drive']
['side', 'due', 'demand', 'users', 'increased', 'connect', 'sleep', 'explain', 'fast', 'dave']
0.10086242697862302
['mac', 'apple', 'monitor', 's