In [2]:
# Import necessary libraries
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from model.sasrec import SASRecModel
from trainers import Trainer
from utils import EarlyStopping, check_path, set_seed, set_logger
from dataset import get_seq_dic, get_dataloder, get_rating_matrix

# Set up arguments
class Args:
    data_dir = "./data/"
    output_dir = "output/"
    data_name = "Beauty"
    do_eval = False
    load_model = None
    train_name = "sasrec_model"
    num_items = 10
    num_users = 10
    lr = 0.001
    batch_size = 256
    epochs = 10
    no_cuda = False
    log_freq = 1
    patience = 10
    num_workers = 0  # Set num_workers to 0 to avoid BrokenPipeError on Windows
    seed = 42
    weight_decay = 0.0
    adam_beta1 = 0.9
    adam_beta2 = 0.999
    gpu_id = "0"
    variance = 5
    model_type = 'sasrec'
    max_seq_length = 50
    hidden_size = 64
    num_hidden_layers = 2
    hidden_act = "gelu"
    num_attention_heads = 2
    attention_probs_dropout_prob = 0.5
    hidden_dropout_prob = 0.5
    initializer_range = 0.02

args = Args()

if __name__ == "__main__":
    # Initialize logger
    log_path = os.path.join(args.output_dir, args.train_name + '.log')
    logger = set_logger(log_path)

    # Set seed for reproducibility
    set_seed(args.seed)

    # Create output directory if not exists
    check_path(args.output_dir)

    # Set CUDA environment
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    args.cuda_condition = torch.cuda.is_available() and not args.no_cuda

    # Load data
    seq_dic, max_item, num_users = get_seq_dic(args)
    args.item_size = max_item + 1
    args.num_users = num_users + 1

    # Prepare checkpoint paths
    args.checkpoint_path = os.path.join(args.output_dir, args.train_name + '.pt')
    args.same_target_path = os.path.join(args.data_dir, args.data_name+'_same_target.npy')

    # Load dataloaders
    train_dataloader, eval_dataloader, test_dataloader = get_dataloder(args, seq_dic)

    # Initialize and log model
    logger.info(str(args))
    model = SASRecModel(args=args)
    logger.info(model)

    # Initialize trainer
    trainer = Trainer(model, train_dataloader, eval_dataloader, test_dataloader, args, logger)

    # Generate rating matrices for evaluation
    args.valid_rating_matrix, args.test_rating_matrix = get_rating_matrix(args.data_name, seq_dic, max_item)

    # Training and evaluation
    if args.do_eval:
        if args.load_model is None:
            logger.info(f"No model input!")
            exit(0)
        else:
            args.checkpoint_path = os.path.join(args.output_dir, args.load_model + '.pt')
            trainer.load(args.checkpoint_path)
            logger.info(f"Load model from {args.checkpoint_path} for test!")
            scores, result_info = trainer.test(0)
    else:
        early_stopping = EarlyStopping(args.checkpoint_path, logger=logger, patience=args.patience, verbose=True)
        for epoch in range(args.epochs):
            trainer.train(epoch)
            scores, _ = trainer.valid(epoch)
            # evaluate on MRR
            early_stopping(np.array(scores[-1:]), trainer.model)
            if early_stopping.early_stop:
                logger.info("Early stopping")
                break

        logger.info("---------------Test Score---------------")
        trainer.model.load_state_dict(torch.load(args.checkpoint_path))
        scores, result_info = trainer.test(0)

    logger.info(args.train_name)
    logger.info(result_info)

2024-05-29 15:30:52,765 - <__main__.Args object at 0x1041e21f0>
2024-05-29 15:30:52,765 - <__main__.Args object at 0x1041e21f0>
2024-05-29 15:30:52,787 - SASRecModel(
  (item_embeddings): Embedding(12102, 64, padding_idx=0)
  (position_embeddings): Embedding(50, 64)
  (LayerNorm): LayerNorm()
  (dropout): Dropout(p=0.5, inplace=False)
  (item_encoder): TransformerEncoder(
    (blocks): ModuleList(
      (0-1): 2 x TransformerBlock(
        (layer): MultiHeadAttention(
          (query): Linear(in_features=64, out_features=64, bias=True)
          (key): Linear(in_features=64, out_features=64, bias=True)
          (value): Linear(in_features=64, out_features=64, bias=True)
          (softmax): Softmax(dim=-1)
          (attn_dropout): Dropout(p=0.5, inplace=False)
          (dense): Linear(in_features=64, out_features=64, bias=True)
          (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
          (out_dropout): Dropout(p=0.5, inplace=False)
        )
        (feed_f

In [3]:
num_batches_to_show = 3  # 확인하고자 하는 배치의 수

for i, batch in enumerate(test_dataloader):
    if i >= num_batches_to_show:
        break  # 지정된 수의 배치를 확인한 후 종료
    indices, input_ids, answers, _, _ = batch
    print(f"Batch {i+1}:")
    print("Indices:", indices)
    print("Input IDs:", input_ids)
    print("Answers:", answers)
    print("-" * 50)

Batch 1:
Indices: tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175

In [4]:
def get_attention_weights(model, input_ids):
    model.eval()
    with torch.no_grad():
        extended_attention_mask = model.get_attention_mask(input_ids)
        sequence_emb = model.add_position_embedding(input_ids)
        item_encoded_layers = model.item_encoder(sequence_emb, extended_attention_mask, output_all_encoded_layers=True)
        attention_weights = item_encoded_layers[-1]  # Extract the attention weights from the last layer
        #attention_probs = torch.nn.functional.softmax(attention_weights, dim=-1)
    return attention_weights

# Example usage with a batch from the test dataloader
batch = next(iter(test_dataloader))
input_ids = batch[1]
attention_weights = get_attention_weights(model, input_ids)
mean_attention_weights = torch.mean(attention_weights, dim=-1)
print(mean_attention_weights.size())
attention_probs = torch.nn.functional.softmax(mean_attention_weights, dim=-1)
attention_probs[3]

torch.Size([256, 50])


tensor([0.0197, 0.0198, 0.0199, 0.0196, 0.0199, 0.0197, 0.0200, 0.0199, 0.0195,
        0.0201, 0.0205, 0.0199, 0.0198, 0.0207, 0.0205, 0.0195, 0.0213, 0.0209,
        0.0200, 0.0204, 0.0205, 0.0208, 0.0204, 0.0203, 0.0196, 0.0209, 0.0209,
        0.0202, 0.0205, 0.0200, 0.0201, 0.0201, 0.0208, 0.0196, 0.0198, 0.0197,
        0.0195, 0.0185, 0.0192, 0.0197, 0.0198, 0.0196, 0.0197, 0.0196, 0.0201,
        0.0195, 0.0188, 0.0199, 0.0204, 0.0201])