In [1]:
from main import settings_to_fn
import argparse
from omegaconf import OmegaConf
from models import init_model
import sys
import torch
from dataset import NoteTupleDataset
from torch.utils.data import DataLoader

# The analysis is pinned to the CPU. The CUDA auto-detection that used to
# precede this line was overwritten immediately and therefore had no effect,
# so only the effective assignment is kept.
device = "cpu"

In [2]:
# Experiment-specific YAML file that is merged over the default config.
config_file = "configs/note_tuple/baseline.yaml"

# Namespace with the `size` / `debug` options that is later handed to
# `settings_to_fn`.
parser = argparse.ArgumentParser()
parser.add_argument("--size", type=str, default="small")
parser.add_argument("--debug", action="store_true")
# Parse an explicit empty argument list: inside a notebook `sys.argv` holds the
# kernel launcher's own arguments, which this parser would reject. Passing `[]`
# yields the defaults without overwriting the process-wide `sys.argv`.
args = parser.parse_args([])

In [3]:
# Load the default config and the experiment-specific config, then merge them.
# `cfg_setting` is passed after `cfg_def`, so with OmegaConf.merge its values
# take precedence over the defaults.
# NOTE(review): the default file is hard-coded to "small" although `args.size`
# carries the same default — confirm whether it should follow `args.size`.
cfg_def = OmegaConf.load("configs/small.yaml")
cfg_setting = OmegaConf.load(config_file)
cfg = OmegaConf.merge(cfg_def, cfg_setting)
# One sequence per batch: the cells below only ever look at batch index 0.
cfg.data.batch_size = 1
# NOTE(review): `settings_to_fn` is a project helper; `fn` is not used in the
# visible cells — presumably a run/file name derived from the settings, confirm.
fn = settings_to_fn(cfg, args)

In [4]:
# Build the model from the merged config; the second return value is discarded.
# NOTE(review): the third positional argument looks like a device string and
# the fourth like a boolean switch — confirm against `models.init_model`.
# NOTE(review): no checkpoint is loaded and train/eval mode is not set in this
# cell — verify that this is intended before interpreting the attention maps.
model, _ = init_model(cfg.model, cfg.music_rep, "cpu", False)
# Move the model to the device selected in the first cell.
model = model.to(device)

Number of trainable parameters: 3533799


In [5]:
class SimpleHook:
    """Record the inputs and output of a module's most recent forward call.

    A forward hook is registered on ``module`` at construction time. After each
    forward pass of that module the positional arguments, keyword arguments and
    output are available as ``args``, ``kwargs`` and ``output``. All three are
    ``None`` until the module has been called at least once.

    Attributes:
        args: Tuple of positional arguments of the last forward call, or None.
        kwargs: Dict of keyword arguments of the last forward call, or None.
        output: Return value of the last forward call, or None.
        hook: Handle returned by ``register_forward_hook``.
    """

    def __init__(self, module):
        """Attach the recorder to ``module``.

        Args:
            module: The ``torch.nn.Module`` whose forward calls are recorded.
        """
        # Initialise the slots so that reading them before the first forward
        # pass yields None instead of raising AttributeError.
        self.args = None
        self.kwargs = None
        self.output = None
        self.hook = module.register_forward_hook(self.hook_fn, with_kwargs=True)

    def hook_fn(self, module, args, kwargs, output):
        """Store the latest call; returns None so the output is not altered."""
        self.args = args
        self.kwargs = kwargs
        self.output = output

    def remove(self):
        """Detach the hook from the module; recorded values are kept."""
        self.hook.remove()

In [6]:
# One recorder per entry of `model.net.layers`, attached to that entry's
# `attn.attn_hook` submodule, in layer order.
hooks = [SimpleHook(block.attn.attn_hook) for block in model.net.layers]

In [7]:
# Fixed seed for the shuffling sampler so that the same sequence is drawn, and
# therefore the same attention map is analysed, on every Restart & Run All.
SAMPLE_SEED = 0

dataset = NoteTupleDataset(cfg.data.data_root, cfg.data.data_src, cfg.data.train_file, False)
loader = DataLoader(
    dataset,
    batch_size=cfg.data.batch_size,
    shuffle=True,
    num_workers=cfg.data.num_workers,
    generator=torch.Generator().manual_seed(SAMPLE_SEED),
)
# A single forward pass is enough: it fires the registered forward hooks, which
# capture the attention tensors inspected in the following cells. Gradients
# are not needed for that.
# NOTE(review): the model's train/eval mode is not set here — if the network
# contains dropout, confirm it is in eval mode before trusting the maps.
with torch.no_grad():
    for batch in loader:
        loss = model(batch.to(device))
        break
    else:
        # Without this, an empty dataset would leave `batch` undefined (or
        # stale from an earlier run) and fail confusingly further down.
        raise RuntimeError("DataLoader yielded no batches; cannot run the forward pass")

                                                     

In [11]:
# Attention captured by the first layer's hook, passed to the hooked module as
# the `attn` keyword argument.
# batch_size, head, seq_len, seq_len
layer0_attn = hooks[0].kwargs["attn"]
# Take the first item of the batch, then average over the leading (head) axis.
attn = layer0_attn[0].mean(dim=0)

In [14]:
# Show the bottom-right 10x10 corner of the head-averaged attention map, i.e.
# the last ten rows against the last ten columns.
attn[-10:, -10:]

tensor([[0.0009, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0009, 0.0010, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0009, 0.0011, 0.0010, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0009, 0.0010, 0.0010, 0.0010, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0009, 0.0011, 0.0010, 0.0009, 0.0010, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0008, 0.0011, 0.0010, 0.0010, 0.0010, 0.0009, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0009, 0.0010, 0.0010, 0.0009, 0.0010, 0.0009, 0.0009, 0.0000, 0.0000,
         0.0000],
        [0.0009, 0.0011, 0.0011, 0.0009, 0.0010, 0.0009, 0.0010, 0.0010, 0.0000,
         0.0000],
        [0.0009, 0.0011, 0.0010, 0.0010, 0.0010, 0.0009, 0.0010, 0.0010, 0.0009,
         0.0000],
        [0.0009, 0.0010, 0.0010, 0.0009, 0.0009, 0.0008, 0.0009, 0.0010, 0.0009,
         0.0009]])

In [28]:
# Show feature columns 1 and 2 of the first 20 tokens of the first sequence.
# NOTE(review): these two columns are combined into a time value in the next
# cell — presumably bar index and position within the bar; -100 looks like a
# padding / special-token marker. Confirm against NoteTupleDataset.
batch[0, :20, 1: 3]

tensor([[-100, -100],
        [   1,    2],
        [   1,    2],
        [   1,    2],
        [   1,    8],
        [   1,   14],
        [   1,   14],
        [   1,   14],
        [   1,   14],
        [   1,   14],
        [   1,   20],
        [   1,   26],
        [   1,   26],
        [   1,   32],
        [   1,   32],
        [   1,   35],
        [   1,   41],
        [   1,   44],
        [   1,   44],
        [   1,   44]])

In [36]:
# Columns 1 and 2 of every token except the last one of the first sequence.
# NOTE(review): presumably (1-based bar index, position within a 48-step bar)
# — confirm against the dataset's tuple layout.
col1 = batch[0, :-1, 1]
col2 = batch[0, :-1, 2]

# Absolute time per token; tokens whose column 1 is -100 keep -100 as a marker.
time = torch.where(col1 == -100, -100, (col1 - 1) * 48 + col2)
# Sequence positions of the tokens that carry a real time value.
idx_not_100 = torch.where(time != -100)[0]
# Pairwise differences: time_rel[q, k] = time[q] - time[k].
time_rel = time[:, None] - time[None, :]

In [38]:
# Accumulators indexed by relative time offset: the summed attention weight
# and the number of (query, key) pairs observed for each offset.
attn_time_sum = [0] * 768  # 0 ~ 767
attn_time_num = [0] * 768  # 0 ~ 767

In [39]:
# Bin the attention weights by the time offset between query and key.
# Only pairs where the key is at or before the query in sequence order are
# visited, and only tokens with a real time value take part.
num_bins = len(attn_time_sum)
num_valid_token = len(idx_not_100)
num_skipped_pairs = 0
for i in range(num_valid_token):
    idx_query = idx_not_100[i]  # invariant for the whole inner loop
    for j in range(i + 1):
        idx_key = idx_not_100[j]
        offset = int(time_rel[idx_query, idx_key])
        # A negative offset would silently wrap around to the end of the list
        # through Python's negative indexing and pollute an unrelated bin; an
        # offset past the last bin has nowhere to go. Skip and count both.
        if not 0 <= offset < num_bins:
            num_skipped_pairs += 1
            continue
        attn_time_sum[offset] += attn[idx_query, idx_key]
        attn_time_num[offset] += 1

if num_skipped_pairs:
    print(f"Skipped {num_skipped_pairs} query/key pairs with a time offset outside [0, {num_bins})")

# Mean attention per offset; offsets that were never observed stay at 0.
attn_time_avg = [
    total / count if count != 0 else 0
    for total, count in zip(attn_time_sum, attn_time_num)
]

In [42]:
# Show the mean attention for relative time offsets 0..49; a plain 0 marks an
# offset for which no query/key pair was observed.
attn_time_avg[:50]

[tensor(0.0196),
 0,
 0,
 tensor(0.0173),
 0,
 0,
 tensor(0.0143),
 0,
 0,
 tensor(0.0184),
 0,
 0,
 tensor(0.0160),
 0,
 0,
 tensor(0.0173),
 0,
 0,
 tensor(0.0127),
 0,
 0,
 tensor(0.0182),
 0,
 0,
 tensor(0.0116),
 0,
 0,
 tensor(0.0198),
 0,
 0,
 tensor(0.0125),
 0,
 0,
 tensor(0.0170),
 0,
 0,
 tensor(0.0111),
 0,
 0,
 tensor(0.0159),
 0,
 0,
 tensor(0.0112),
 0,
 0,
 tensor(0.0147),
 0,
 0,
 tensor(0.0106),
 0]