In [1]:
import sys
import torch
import numpy as np
from typing import List, Optional
#
sys.path.insert(1, "lib")
#
from clozenet import Pretraining, Finetuning

#### [A01] Test the pretraining model

In [2]:
# Seed torch before drawing the random inputs and before the model initializes
# its parameters, so the printed loss is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

### Pretraining architecture hyperparameters (scaled down from the paper)
num_total_char = 200
char_dim = 128
width_list = [1,2,3,4,5,6]
num_filters_list = [128,256,384,512,512,512]
att_output_dim = 1024
ffw_dim = 4096
num_heads = 16
num_block = 12
num_heads_last = 32
vocab_size = 1000 ### in paper 1.000.000
cutoffs = [64,220] ### in paper [64.000, 220.000]

### Toy batch: 13 sentences, each with 11 words (word indices start from 0).
### Words 0-5 feed the forward input, word 6 is the prediction target,
### words 7-10 feed the backward input.
batch_size = 13
num_fw_words = 6
num_bw_words = 4
max_len_plus2 = 7  ### character slots per word, i.e. max_len + 2
### shape (batch_size, seq_size, max_len+2) = (13,6,7)
ts10A_input_fw = torch.randint(0, num_total_char, (batch_size, num_fw_words, max_len_plus2))
### shape (batch_size, seq_size, max_len+2) = (13,4,7)
ts20A_input_bw = torch.randint(0, num_total_char, (batch_size, num_bw_words, max_len_plus2))
start_pos_fw = 0
### The backward words start right after the held-out target word: 0 + 6 + 1 = 7
start_pos_bw = start_pos_fw + num_fw_words + 1
### shape (batch_size)
target = torch.randint(0, vocab_size, (batch_size,))
###
pretrain_model = Pretraining(num_total_char, char_dim, width_list, num_filters_list,
    att_output_dim, ffw_dim, num_heads, num_block, num_heads_last,
    vocab_size, cutoffs)
#
# The model returns a pair. Called without a target, we keep the second element
# (the prediction). Called with a target, we keep the first element (the loss).
_, prediction = pretrain_model(ts10A_input_fw, ts20A_input_bw, start_pos_fw, start_pos_bw)
loss, _ = pretrain_model(ts10A_input_fw, ts20A_input_bw, start_pos_fw, start_pos_bw, target)
# Sanity checks: one predicted word per sentence, and a scalar (reduced) loss.
assert prediction.shape == (batch_size,), prediction.shape
assert loss.dim() == 0, loss.shape
print(prediction.shape)
print(loss, loss.shape)

torch.Size([13])
tensor(10.3762, grad_fn=<MeanBackward0>) torch.Size([])


#### [A05] Test the finetuning model

In [3]:
### The finetuning model is built on the pretrained two-tower encoder
### (`pretrain_model.two_tower`, from the pretraining cell above). The shared
### architecture hyperparameters are therefore reused from that cell rather than
### re-typed here, where a typo could silently make them disagree with the
### pretrained tower. The shared hyperparameters are:
###   num_total_char, char_dim, width_list, num_filters_list,
###   att_output_dim, ffw_dim, num_heads, num_block
### Finetuning-specific hyperparameters:
lstm_hid_dim = 4096
lstm_proj_dim = 512
num_tags = 5
pretrain_twotower = pretrain_model.two_tower
###
### Toy batch: 13 sentences, each sentence has 11 words,
### and each word is a slot of max_len+2 character ids.
batch_size = 13
num_words = 11
max_len_plus2 = 7
### shape (batch_size, seq_size, max_len+2) = (13,11,7)
ts10A_input = torch.randint(0, num_total_char, (batch_size, num_words, max_len_plus2))
### one tag id per word: shape (batch_size, seq_size) = (13,11)
target = torch.randint(0, num_tags, (batch_size, num_words))
###
finetuning_model = Finetuning(num_total_char, char_dim, width_list, num_filters_list,
    att_output_dim, ffw_dim, num_heads, num_block,
    lstm_hid_dim, lstm_proj_dim, num_tags, pretrain_twotower)

In [4]:
# The model returns a pair. Called without tags, we keep the second element
# (the prediction). Called with `target`, we keep the first element (the loss).
_, prediction = finetuning_model(ts10A_input)
loss, _ = finetuning_model(ts10A_input, target)
# Sanity checks: one predicted tag per word (same shape as the tag target),
# and a scalar (reduced) loss.
assert prediction.shape == target.shape, prediction.shape
assert loss.dim() == 0, loss.shape
print(prediction.shape)
print(loss, loss.shape)

torch.Size([13, 11])
tensor(230.2599, grad_fn=<SumBackward0>) torch.Size([])
