In [1]:
"""
思路: 在TR模型的基础上, 融合用户的文本特征, 以异质图Heter-GAT的方式融合用户特征, 做信息传播预测任务
方法:
1. 整理原始数据, 构建用户有向关联网络, 并根据原始文本内容计算用户文本嵌入向量
2. 考虑节点类型为User和Tweet, 边类型为U-U和U-T, 分别从用户特征和文本特征的角度通过GAT网络融合邻域节点特征;
   Heter-GAT模型的输出为(N,|Rs|+1,D')维度, 模型后面需要接一个全连接层FC=(|Rs|+1)*D'->2, 损失函数保持为NLL-Loss
3. 可视化局部邻域, 观察不同注意力头、不同异质图邻域卷积的偏向
"""

"\n思路: 在TR模型的基础上, 融合用户的文本特征, 以异质图Heter-GAT的方式融合用户特征, 做信息传播预测任务\n方法:\n1. 整理原始数据, 构建用户有向关联网络, 并根据原始文本内容计算用户文本嵌入向量\n2. 考虑节点类型为User和Tweet, 边类型为U-U和U-T, 分别从用户特征和文本特征的角度通过GAT网络融合邻域节点特征;\n   Heter-GAT模型的输出为(N,|Rs|+1,D')维度, 模型后面需要接一个全连接层FC=(|Rs|+1)*D'->2, 损失函数保持为NLL-Loss\n3. 可视化局部邻域, 观察不同注意力头、不同异质图邻域卷积的偏向\n"

In [2]:
# import sys
# import os
# sys.path.append(os.path.dirname(os.getcwd()))
# from lib.log import logger
# from utils import load_pickle, save_pickle, summarize_distribution, find_rt_bound, HeterSubGraphSample
# from lib.utils import get_node_types, extend_edges, create_sparse, get_sparse_tensor
# import numpy as np

In [1]:
import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))
from lib.log import logger
from lib.utils import get_sparse_tensor
from utils import HeterSubGraphSample, load_pickle, save_pickle, init_args, ChunkSampler, HeterGraphDataset, sparse_batch_collate
from model import HeterGraphAttentionNetwork
import numpy as np
import time
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, precision_recall_curve
from tensorboard_logger import tensorboard_logger
from scipy.sparse import csr_matrix


2022-10-21 18:23:26,046 Note: NumExpr detected 56 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2022-10-21 18:23:26,048 NumExpr defaulting to 8 threads.


In [2]:
samples = load_pickle("/root/data/HeterGAT/stages/hs_subg483_inf_40_1718027_deg_18_483_ego_20_neg_1_restart_20/heter_samples.p")


In [11]:
samples.vertex_ids[0]

array([  39949,   23831,   15772,    4542,   21571,   25925,   40133,
         20049,   16095,   36448,   20447,    1637,   38896,    3573,
         12535,   35704,   15481,   15994,    2427,   40452, 4369418,
       4369419, 4369420, 4369423, 4531220, 4531221, 4531222, 4531225,
       4531226, 4531227, 4369443, 4369444, 4531243, 4531244, 4531246,
       4531247, 4369461, 4531253, 4531265, 4503658, 4369483, 4369484,
       4369487, 4503633, 4503634, 4531284, 4503636, 4531285, 4503639,
       4503640, 4503641, 4503642, 4531291, 4503644, 4503645, 4503643,
       5553759, 5553760, 4503648, 4503650, 5553763, 5553764, 4503653,
       4503652, 4531298, 5553768, 4503657, 5553770, 4503659, 4503651,
       5553762, 4503655, 4503663, 4503662, 4503656, 4503666, 4503665,
       5553771, 4503669, 4503661, 5553775, 4503672, 5553776, 4503674,
       4503673, 4503676, 4503675, 4503664, 4503679, 4503677, 4503681,
       4503667, 5553806, 4503668, 5553808, 5553810, 5553812, 5553772,
       4503670, 5553

In [2]:
hs = load_pickle("/root/data/HeterGAT/stages/hs_subg483_inf_40_1718027_deg_18_483_ego_50_neg_1_restart_20/heter_samples.p")


In [5]:
hs.labels

(49926, 2)

In [66]:
# NOTE: Fake Digg Heter Dataset
from torch.utils.data import Dataset
from utils import SubGraphSample, load_w2v_feature

class DiggDataset(Dataset):
    def __init__(self, samples: SubGraphSample, embedding) -> None:
        super().__init__()
        self.adjs = samples.adj_matrices
        self.labels = samples.labels
        self.feats = samples.influence_features
        self.vertex_ids = samples.vertex_ids
        self.concact_feats(embedding)
    def concact_feats(self, embedding):
        feats = []
        for idx, vertex_ids in enumerate(self.vertex_ids):
            emb_feats = [embedding[user] for user in vertex_ids]
            feats.append(np.concatenate((self.feats[idx], emb_feats), axis=1))
        self.feats = np.array(feats)
        logger.info(self.feats.shape)
    def __len__(self):
        return self.labels.shape[0]
    def __getitem__(self, index):
        return self.adjs[index], self.labels[index], self.feats[index]

def collate_fn2(batch:list): 
    """
    Collate function which to transform scipy coo matrix to pytorch sparse tensor
    """
    adjs_batch, labels_batch, feats_batch = zip(*batch)
    adjs_batch = torch.FloatTensor(np.array(adjs_batch))
    
    if type(labels_batch[0]).__module__ == 'numpy':
        # NOTE: https://stackoverflow.com/questions/69742930/runtimeerror-nll-loss-forward-reduce-cuda-kernel-2d-index-not-implemented-for
        labels_batch = torch.LongTensor(labels_batch)
    
    if type(feats_batch[0]).__module__ == 'numpy':
        feats_batch = torch.FloatTensor(np.array(feats_batch))
    return adjs_batch, labels_batch, feats_batch

def digg_load_dataset(train_ratio=60, valid_ratio=20, batch_size=256):
    embedding_path = "/root/Lab_Related/data/Heter-GAT/Classic/deepwalk/deepwalk_added.emb_64"
    vertices = np.load("/root/TR-pptusn/DeepInf-preprocess/preprocess/stages_op_inf_100_1k/vertex_id.npy")
    max_vertex_idx = np.max(vertices)
    embedding = load_w2v_feature(embedding_path, max_vertex_idx)
    # embedding = torch.FloatTensor(embedding)

    samples = SubGraphSample(
        adj_matrices=np.load("/root/TR-pptusn/DeepInf-preprocess/preprocess/stages_op_inf_100_1k/adjacency_matrix.npy"),
        influence_features=np.load("/root/TR-pptusn/DeepInf-preprocess/preprocess/stages_op_inf_100_1k/influence_feature.npy"),
        vertex_ids=np.load("/root/TR-pptusn/DeepInf-preprocess/preprocess/stages_op_inf_100_1k/vertex_id.npy"),
        labels=np.load("/root/TR-pptusn/DeepInf-preprocess/preprocess/stages_op_inf_100_1k/label.npy")
    )
    dataset = DiggDataset(samples, embedding)
    nb_samples    = len(dataset)
    
    train_start,  valid_start, test_start = 0, int(nb_samples*train_ratio/100), int(nb_samples*(train_ratio+valid_ratio)/100)
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=ChunkSampler(valid_start-train_start, 0), collate_fn=collate_fn2)
    valid_loader = DataLoader(dataset, batch_size=batch_size, sampler=ChunkSampler(test_start-valid_start, valid_start), collate_fn=collate_fn2)
    test_loader  = DataLoader(dataset, batch_size=batch_size, sampler=ChunkSampler(nb_samples - test_start, test_start), collate_fn=collate_fn2)
    logger.info(f"Finish Loading Dataset... train={len(train_loader)}, valid={len(valid_loader)}, test={len(test_loader)}")

    return samples, train_loader, valid_loader, test_loader

In [67]:
GPU_MODEL = 'cuda:2'

# 1. 
def load_dataset(data_filepath:str, train_ratio:float, valid_ratio:float, batch_size:int):
    # heter_samples = load_pickle(os.path.join(data_dirpath, "heter_samples_tensor.p"))
    heter_samples = load_pickle(data_filepath)
    dataset       = HeterGraphDataset(heter_samples=heter_samples)
    nb_samples    = len(dataset)
    
    train_start,  valid_start, test_start = 0, int(nb_samples*train_ratio/100), int(nb_samples*(train_ratio+valid_ratio)/100)
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=ChunkSampler(valid_start-train_start, 0), collate_fn=sparse_batch_collate)
    valid_loader = DataLoader(dataset, batch_size=batch_size, sampler=ChunkSampler(test_start-valid_start, valid_start), collate_fn=sparse_batch_collate)
    test_loader  = DataLoader(dataset, batch_size=batch_size, sampler=ChunkSampler(nb_samples - test_start, test_start), collate_fn=sparse_batch_collate)
    logger.info(f"Finish Loading Dataset... train={len(train_loader)}, valid={len(valid_loader)}, test={len(test_loader)}")

    return heter_samples, train_loader, valid_loader, test_loader

# args = init_args()
args = {
    "train_ratio": 60,
    "valid_ratio": 20,
    "batch": 256,
    # "nb_user": 50,
    "class_weight_balanced": True,
    "hidden_units": "16,16",
    "heads": "8,8",
    "cuda": True,
    "lr": 0.1,
    "weight_decay": 5e-4,
    "dropout": 0.2,
    "seed": 42,
    "epochs": 100,
    "check_point": 2,
    # "file_dir": "heter_samples_ratio1/0/heter_samples.p",
    "file_dir": "hs_new_ratio1/heter_samples.p",
}
# heter_samples, train_loader, valid_loader, test_loader = load_dataset(args["file_dir"], args["train_ratio"], args["valid_ratio"], args["batch"])
samples, train_loader, valid_loader, test_loader = digg_load_dataset()
nb_samples = len(samples)
nb_classes = 2
class_weight = torch.FloatTensor(nb_samples / (nb_classes*np.bincount(samples.labels))) if args["class_weight_balanced"] else torch.ones(nb_classes)
nb_user = 50
n_units = [samples.influence_features.shape[2]+64]+[int(x) for x in args["hidden_units"].strip().split(",")]
n_heads = [int(x) for x in args["heads"].strip().split(",")]
logger.info(f"nb_samples={nb_samples}, class_weight={class_weight[0]:.2f}:{class_weight[1]:.2f}, n_units={n_units}, n_heads={n_heads}")
# logger.info("class_weight=%.2f:%.2f", class_weight[0], class_weight[1])

2022-10-12 15:42:11,284 n=204955, d=64
2022-10-12 15:42:21,005 (38152, 50, 66)
2022-10-12 15:42:21,017 Finish Loading Dataset... train=90, valid=30, test=30
2022-10-12 15:42:21,184 nb_samples=38152, class_weight=0.57:4.05, n_units=[66, 16, 16], n_heads=[8, 8]


In [68]:
# 2. 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
import copy
from model import BatchSparseMultiHeadGraphAttention, BatchMultiHeadGraphAttention, BatchAdditiveAttention

class HeterGraphAttentionNetwork2(nn.Module):
    def __init__(
        self, n_user, nb_node_kinds=2, nb_loop_nodes=[50,1050],
        nb_classes=2, n_units=[25,64], n_heads=[3],
        attn_dropout=0.5, dropout=0.1, 
        d2=64, gpu_device_ids=[] 
    ) -> None:
        """
        Args:
            gpu_device_ids(List[int], default=[]): 采用Model Parallel方法, 主要使多头注意力可以在不同GPU上运行, 
                模型以数据所属GPU为例, 先在该参数指定的不同GPU上执行单一注意力头, 再在最后将所有注意力头的运行结果复制回数据所属的主GPU上
        """
        super().__init__()

        self.n_layer = len(n_units) - 1
        self.dropout = dropout
        self.gpu_device_ids = gpu_device_ids

        self.d = n_units[0]
        self.d1 = n_units[1]
        self.d2 = n_units[1]
        self.n_user = n_user

        self.layer_stack = nn.ModuleList()
        for hidx in range(nb_node_kinds):
            layer_stack = nn.ModuleList()
            for i in range(self.n_layer):
                # consider multi head from last layer
                f_in = n_units[i] * n_heads[i - 1] if i else n_units[i]
                layer_stack.append(
                    # BatchMultiHeadGraphAttention(nb_heads=n_heads[i], nb_in_feats=f_in, nb_out_feats=n_units[i + 1], 
                    #     nb_loop_nodes=nb_loop_nodes[hidx], attn_dropout=attn_dropout)
                    BatchMultiHeadGraphAttention(n_head=n_heads[i], f_in=f_in, f_out=n_units[i+1], attn_dropout=attn_dropout)
                )
            self.layer_stack.append(layer_stack)
        self.additive_attention = BatchAdditiveAttention(d=self.d, d1=self.d1, d2=self.d2)
        self.fc_layer = nn.Linear(in_features=self.d1*(nb_node_kinds+1), out_features=nb_classes)
    
    def forward(self, h, hadj):
        # NOTE: h: (bs, N, fin), hadj: (|Rs|, bs, N, N)
        bs, n = h.shape[:2]
        heter_embs = []
        for heter_idx, layer_stack in enumerate(self.layer_stack):
            x = copy.deepcopy(h)
            for i, gat_layer in enumerate(layer_stack):
                x = gat_layer(x, hadj[heter_idx]) # output: (bs, n_head, n, f_out)
                if i + 1 == self.n_layer:
                    x = x.mean(dim=-3) # (bs, n_head, n, f_out) -> (bs, n, f_out)
                else:
                    x = F.elu(x.reshape(bs, n, -1))
                    x = F.dropout(x, self.dropout, training=self.training)
            heter_embs.append(x[:,:self.n_user].unsqueeze(-2)) # (bs, Nu, 1, f_out)
        type_aware_emb = torch.cat(heter_embs, dim=-2) # (bs, Nu, |Rs|, D')
        type_fusion_emb = self.additive_attention(h[:,:self.n_user], type_aware_emb) # (bs, Nu, 1, D')
        ret = self.fc_layer(
            torch.cat((type_fusion_emb, type_aware_emb), dim=-2).reshape(bs, self.n_user,-1) # (bs, Nu, |Rs|+1, D') -> (bs, Nu, (|Rs|+1)*D')
        ) #  (bs, Nu, nb_classes)
        return F.log_softmax(ret, dim=-1)

model = HeterGraphAttentionNetwork2(n_user=nb_user, n_units=n_units, nb_classes=nb_classes, n_heads=n_heads, dropout=args["dropout"])
if args["cuda"]:
    # model.cuda()
    model.to(GPU_MODEL)
    # class_weight = class_weight.cuda()
    class_weight = class_weight.to(GPU_MODEL)
# params = [{'params': model.parameters()}]
optimizer = optim.Adagrad(model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"])

In [69]:
def train(epoch, train_loader, valid_loader, test_loader, log_desc='train_'):
    model.train()

    loss = 0.
    total = 0.
    for i_batch, batch in enumerate(train_loader):
        # if i_batch % 10 == 0:
        #     logger.info(f"i_batch={i_batch}")
        adjs, labels, feats = batch
        bs = adjs.size(0)

        if args["cuda"]:
            adjs    = adjs.to(GPU_MODEL)
            labels  = labels.to(GPU_MODEL)
            feats   = feats.to(GPU_MODEL)

        optimizer.zero_grad()
        output = model(feats, torch.stack([adjs, adjs]))
        output = output[:,-1,:] # choose last user

        loss_train = F.nll_loss(output, labels, class_weight)
        loss += bs * loss_train.item()
        total += bs
        loss_train.backward()
        optimizer.step()
    logger.info("train loss in this epoch %f", loss / total)

train(0, train_loader, valid_loader, test_loader)

2022-10-12 15:46:18,982 train loss in this epoch 1.182631


In [None]:
# for i_batch, batch in enumerate(train_loader):
#     uu_adjs, ut_adjs, labels, feats = batch
#     bs = uu_adjs.size(0)

#     # feats = feats[:,:100]
#     # labels = labels
#     # nw_uu_adjs, nw_ut_adjs = [], []
#     # for idx in range(len(uu_adjs)):
#     #     nw_uu_adjs.append(torch.LongTensor(uu_adjs[idx].to_dense().numpy()[:100,:100]+np.eye(100)))
#     #     nw_ut_adjs.append(torch.LongTensor(ut_adjs[idx].to_dense().numpy()[:100,:100]+np.eye(100)))
#     # uu_adjs = torch.stack(nw_uu_adjs)
#     # ut_adjs = torch.stack(nw_ut_adjs)

#     if args["cuda"]:
#         uu_adjs, ut_adjs, labels, feats = uu_adjs.to(GPU_MODEL), ut_adjs.to(GPU_MODEL), labels.to(GPU_MODEL), feats.to(GPU_MODEL)
#     break

# model = HeterGraphAttentionNetwork(n_user=nb_user, n_units=n_units, nb_classes=nb_classes, n_heads=n_heads, dropout=args["dropout"])
# if args["cuda"]:
#     model.to(GPU_MODEL)
#     class_weight = class_weight.to(GPU_MODEL)
# optimizer = optim.Adagrad(model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"])

# output = model(feats, torch.stack([uu_adjs, ut_adjs]))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from model import BatchSparseMultiHeadGraphAttention, BatchAdditiveAttention, BatchMultiHeadGraphAttention
import copy

class HeterGraphAttentionNetwork2(nn.Module):
    def __init__(
        self, n_user, nb_node_kinds=2, nb_loop_nodes=[50,1050],
        nb_classes=2, n_units=[25,64], n_heads=[3],
        attn_dropout=0.5, dropout=0.1, 
        d2=64, gpu_device_ids=[] 
    ) -> None:
        super().__init__()

        self.n_layer = len(n_units) - 1
        self.dropout = dropout
        self.gpu_device_ids = gpu_device_ids

        self.d = n_units[0]
        self.d1 = n_units[1]
        self.d2 = d2
        self.n_user = n_user

        self.layer_stack = nn.ModuleList()
        for hidx in range(nb_node_kinds):
            layer_stack = nn.ModuleList()
            for i in range(self.n_layer):
                # consider multi head from last layer
                f_in = n_units[i] * n_heads[i - 1] if i else n_units[i]
                layer_stack.append(
                    BatchMultiHeadGraphAttention(n_head=n_heads[i], f_in=f_in, f_out=n_units[i + 1], 
                        attn_dropout=attn_dropout)
                )
            self.layer_stack.append(layer_stack)
        self.additive_attention = BatchAdditiveAttention(d=self.d, d1=self.d1, d2=self.d2)
        self.fc_layer = nn.Linear(in_features=self.d1*(nb_node_kinds+1), out_features=nb_classes)
    
    def forward(self, h, hadj):
        # NOTE: h: (bs, N, fin), hadj: (|Rs|, bs, N, N)
        bs, n = h.shape[:2]
        heter_embs = []
        for heter_idx, layer_stack in enumerate(self.layer_stack):
            x = copy.deepcopy(h)
            for i, gat_layer in enumerate(layer_stack):
                x = gat_layer(x, hadj[heter_idx]) # output: (bs, n_head, n, f_out)
                if i + 1 == self.n_layer:
                    x = x.mean(dim=-3) # (bs, n_head, n, f_out) -> (bs, n, f_out)
                else:
                    x = F.elu(x.reshape(bs, n, -1))
                    x = F.dropout(x, self.dropout, training=self.training)
            logger.info(x)
            heter_embs.append(x[:,:self.n_user].unsqueeze(-2)) # (bs, Nu, 1, f_out)
        type_aware_emb = torch.cat(heter_embs, dim=-2) # (bs, Nu, |Rs|, D')
        type_fusion_emb = self.additive_attention(h[:,:self.n_user], type_aware_emb) # (bs, Nu, 1, D')
        ret = self.fc_layer(
            torch.cat((type_fusion_emb, type_aware_emb), dim=-2).reshape(bs, self.n_user,-1) # (bs, Nu, |Rs|+1, D') -> (bs, Nu, (|Rs|+1)*D')
        ) #  (bs, Nu, nb_classes)
        return F.log_softmax(ret, dim=-1)

model = HeterGraphAttentionNetwork2(n_user=nb_user, n_units=n_units, nb_classes=nb_classes, n_heads=n_heads, dropout=args["dropout"])
if args["cuda"]:
    model.to(GPU_MODEL)
    class_weight = class_weight.to(GPU_MODEL)
optimizer = optim.Adagrad(model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"])

output = model(feats, torch.stack([uu_adjs, ut_adjs]))

In [8]:
# 2. 
model = HeterGraphAttentionNetwork(n_user=nb_user, n_units=n_units, nb_classes=nb_classes, n_heads=n_heads, dropout=args["dropout"])
if args["cuda"]:
    # model.cuda()
    model.to(GPU_MODEL)
    # class_weight = class_weight.cuda()
    class_weight = class_weight.to(GPU_MODEL)
# params = [{'params': model.parameters()}]
optimizer = optim.Adagrad(model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"])


In [4]:
# logger.info(f"out Allocated: {torch.cuda.memory_reserved(int(GPU_MODEL[-1]))/1024**3}")

2022-09-22 12:14:19,308 out Allocated: 0.001953125


In [11]:
def evaluate(epoch, loader, thr=None, return_best_thr=False, log_desc='valid_'):
    model.eval()
    total = 0.
    loss, prec, rec, f1 = 0., 0., 0., 0.
    y_true, y_pred, y_score = [], [], []
    for i_batch, batch in enumerate(loader):
        uu_adjs, ut_adjs, labels, feats = batch
        bs = uu_adjs.size(0)

        if args["cuda"]:
            uu_adjs = uu_adjs.to(GPU_MODEL)
            ut_adjs = ut_adjs.to(GPU_MODEL)
            labels  = labels.to(GPU_MODEL)
            feats   = feats.to(GPU_MODEL)

        output = model(feats, torch.stack([uu_adjs, ut_adjs]))
        output = output[:,-1,:] # choose last user

        loss_batch = F.nll_loss(output, labels, class_weight)
        loss += bs * loss_batch.item()

        y_true += labels.data.tolist()
        # 返回output中每行最大值的索引
        y_pred += output.max(1)[1].data.tolist()
        y_score += output[:, 1].data.tolist()
        total += bs

    model.train()

    if thr is not None:
        logger.info("using threshold %.4f", thr)
        y_score = np.array(y_score)
        y_pred = np.zeros_like(y_score)
        y_pred[y_score > thr] = 1

    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
    auc = roc_auc_score(y_true, y_score)
    logger.info("%sloss: %.4f AUC: %.4f Prec: %.4f Rec: %.4f F1: %.4f",
            log_desc, loss / total, auc, prec, rec, f1)

    if return_best_thr:
        precs, recs, thrs = precision_recall_curve(y_true, y_score)
        f1s = 2 * precs * recs / (precs + recs)
        f1s = f1s[:-1]
        thrs = thrs[~np.isnan(f1s)]
        f1s = f1s[~np.isnan(f1s)]
        best_thr = thrs[np.argmax(f1s)]
        logger.info("best threshold=%4f, f1=%.4f", best_thr, np.max(f1s))
        return best_thr
    else:
        return None

def train(epoch, train_loader, valid_loader, test_loader, log_desc='train_'):
    model.train()

    loss = 0.
    total = 0.
    for i_batch, batch in enumerate(train_loader):
        # if i_batch % 10 == 0:
        #     logger.info(f"i_batch={i_batch}")
        uu_adjs, ut_adjs, labels, feats = batch
        bs = uu_adjs.size(0)

        if args["cuda"]:
            uu_adjs = uu_adjs.to(GPU_MODEL)
            ut_adjs = ut_adjs.to(GPU_MODEL)
            labels  = labels.to(GPU_MODEL)
            feats   = feats.to(GPU_MODEL)

        optimizer.zero_grad()
        output = model(torch.rand((feats.shape)).to(GPU_MODEL), torch.stack([uu_adjs, ut_adjs]))
        logger.info(output.shape)
        output = output[:,-1,:] # choose last user

        loss_train = F.nll_loss(output, labels, class_weight)
        loss += bs * loss_train.item()
        total += bs
        loss_train.backward()
        optimizer.step()
    logger.info("train loss in this epoch %f", loss / total)
    if (epoch + 1) % args["check_point"] == 0:
        logger.info("epoch %d, checkpoint!", epoch)
        best_thr = evaluate(epoch, valid_loader, return_best_thr=True, log_desc='valid_')
        evaluate(epoch, test_loader, thr=best_thr, log_desc='test_')

# for epoch in range(args["epochs"]):
#     train(epoch, train_loader, valid_loader, test_loader)

best_thr = evaluate(0, valid_loader, return_best_thr=True, log_desc='valid_')
evaluate(0, test_loader, thr=best_thr, log_desc='test_')
"""
2022-09-19 18:36:01,818 train loss in this epoch 0.304292
2022-09-19 18:45:58,481 train loss in this epoch 0.096567
2022-09-19 18:45:58,484 epoch 1, checkpoint!
2022-09-19 18:47:17,276 valid_loss: 0.0495 AUC: 0.9972 Prec: 0.9972 Rec: 1.0000 F1: 0.9986
/tmp/ipykernel_30944/2709401001.py:55: RuntimeWarning: invalid value encountered in true_divide
  f1s = 2 * precs * recs / (precs + recs)
2022-09-19 18:47:17,301 best threshold=-0.040288, f1=0.9990
2022-09-19 18:48:35,552 using threshold -0.0403
2022-09-19 18:48:35,601 test_loss: 0.0501 AUC: 0.9976 Prec: 0.9975 Rec: 1.0000 F1: 0.9988
"""

2022-09-22 12:42:28,618 torch.Size([256, 50, 2])
2022-09-22 12:42:28,620 torch.Size([256, 2])
2022-09-22 12:42:29,528 torch.Size([256, 50, 2])
2022-09-22 12:42:29,529 torch.Size([256, 2])
2022-09-22 12:42:30,180 torch.Size([256, 50, 2])
2022-09-22 12:42:30,181 torch.Size([256, 2])


KeyboardInterrupt: 