In [5]:
import model
import torch
from torch import nn

In [17]:
# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
GENERATE_EVERY  = 100
NUM_TOKENS = 16 + 2
ENC_SEQ_LEN = 32
DEC_SEQ_LEN = 32

def cycle():
    while True:
        prefix = torch.ones((BATCH_SIZE, 1)).long()
        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long()
        tgt = torch.cat((prefix, src), 1)
        src_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool()
        tgt_mask = torch.ones(BATCH_SIZE, tgt.shape[1] - 1).bool()
        yield (src, tgt, src_mask, tgt_mask)
        
def culc_loss(loss_func, inputs, targets):
    """
    損失関数の計算
    args:
        - loss_func : 損失関数(交差エントロピー)
        - input (B x len x d): 入力データ
        - target (B x len): ターゲットデータ
    
    文章ごとに平均をとって、バッチごとに平均をとる
    pytorchの交差エントロピー使わない方が収束早いし、lossも小さくなる。。。
    どういうことだ？
    計算結果が微妙に違う気がするし、nn.CrossEntropyが所望の計算をしてない可能性ある？
    """
    B, l, d = inputs.shape
    _loss = 0
    loss = 0
    for i in range(B):
        loss += loss_func(inputs[i], targets[i])# 内部的に文章平均
#         _loss += cross_ent(inputs[i], targets[i])
#     _loss /= B# バッチ平均
    loss /= B
    return loss

In [15]:
# next(cycle())

In [None]:
DIMENTION = 512
HEAD = 8
DEPTH_enc = 1
DEPTH_dec = 3

transformer = model.Model("cpu", DIMENTION, NUM_TOKENS, 0.0, DEPTH_enc, DEPTH_dec, HEAD, HEAD)
criterion = nn.CrossEntropyLoss()
# optimizer

optim = torch.optim.Adam(transformer.parameters(), lr=LEARNING_RATE)

# training

for i in range(NUM_BATCHES):
    transformer.train()
    src, tgt, src_mask, tgt_mask = next(cycle())
    print("src : {}".format(src.shape))
    print("tgt : {}".format(tgt.shape))
    
    out = transformer(src, tgt)
    print(out.shape)
    loss = culc_loss(criterion, out[:,:-1,:], src)
    loss.backward()
    print(loss.item())

    optim.step()
    optim.zero_grad()

    if i % GENERATE_EVERY == 0:
        transformer.eval()
        src, _, src_mask, _ = next(cycle())
        src, src_mask = src[0:1], src_mask[0:1]
        
        sample = transformer.generate(src)
#         import pdb; pdb.set_trace()
#         incorrects = (src != sample).abs().sum()
        incorrects = torch.sum(src != sample)

        print(f"input:  ", src)
        print(f"predicted output:  ", sample)
        print(f"incorrects: {incorrects}")

src : torch.Size([16, 32])
tgt : torch.Size([16, 33])
torch.Size([16, 33, 18])
3.031921625137329
input:   tensor([[ 4, 13, 11, 14, 17,  9, 15,  2,  6,  4, 12, 16,  3,  7, 14, 15, 14, 11,
          7, 13,  2, 10,  8, 15,  6,  7,  7,  6,  3,  4,  5, 12]])
predicted output:   tensor([[ 7, 10, 12,  3,  2, 12,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
          9,  7,  4, 13,  2, 15,  8,  8,  8,  8,  8,  8,  8,  8]])
incorrects: 30
src : torch.Size([16, 32])
tgt : torch.Size([16, 33])
torch.Size([16, 33, 18])
2.9528870582580566
src : torch.Size([16, 32])
tgt : torch.Size([16, 33])
torch.Size([16, 33, 18])
2.883139133453369
src : torch.Size([16, 32])
tgt : torch.Size([16, 33])
torch.Size([16, 33, 18])
2.8683905601501465
src : torch.Size([16, 32])
tgt : torch.Size([16, 33])
torch.Size([16, 33, 18])
2.8277370929718018
src : torch.Size([16, 32])
tgt : torch.Size([16, 33])
torch.Size([16, 33, 18])
2.820688486099243
src : torch.Size([16, 32])
tgt : torch.Size([16, 33])
torch.Size([16, 33, 1