In [1]:
import os
import sys
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from collections import Counter
from scipy.spatial.distance import jensenshannon
if '..' not in os.sys.path:os.sys.path.append('..')
from model import EncoreSequential, WeightNLLLoss
from utils.dist import get_probability_distributions_from_sequence,compute_ngram_distribution, compute_js_distance
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")

In [2]:
dataset = 'icbc'
data_dir = f'/mnt/ssd1/hsj/encore/{dataset}'
size_cdf_file = os.path.join(data_dir, f'{dataset}_cdf', f'{dataset}_size_cdf_coarse.csv')
interval_cdf_file = os.path.join(data_dir, f'{dataset}_cdf', f'{dataset}_interval_cdf_coarse.csv')
size_cdf = pd.read_csv(size_cdf_file)
interval_cdf = pd.read_csv(interval_cdf_file)
n_size, n_interval = len(size_cdf), len(interval_cdf)
start_token = n_size * n_interval
print(f'n_size: {n_size}, n_interval: {n_interval}, start_token: {start_token}')

n_size: 30, n_interval: 30, start_token: 900


In [3]:
def get_trainset(pair_index_file, block_size, n_size, n_interval, min_samples, device):
    pair_index = []
    with open(pair_index_file, 'r') as f:
        lines = f.readlines()
        for line in lines:
            pair_index.append([int(x) for x in line.strip().split(',')])
    
    trainset = []
    seqs = []
    size_probs, interval_probs = [], []
    for i, seq in tqdm(enumerate(pair_index[0:1000])):
        seq_trainset = []
        seq = np.concatenate((seq[:-1], seq[:1]))
        seqs.append(seq)
        size_prob, interval_prob = get_probability_distributions_from_sequence(seq, n_size, n_interval)
        size_probs.append(size_prob)
        interval_probs.append(interval_prob)
        size_seq = seq // n_interval
        permuted_seq = np.random.permutation(size_seq)
        size_pairs = compute_ngram_distribution(size_seq, 2)
        permuted_pairs = compute_ngram_distribution(permuted_seq, 2)
        jsd = compute_js_distance(size_pairs, permuted_pairs) ** 2
        jsd = jsd if jsd > 0.1 else 0.1
        weight = jsd * np.log10(len(seq))

        num_samples = max(min_samples, len(seq) - block_size)
        seq = np.append([start_token], seq)
        for i in range(num_samples):
            index = i % (len(seq) - block_size)
            sequence_block = seq[index:index+block_size]
            target_sequence = seq[index+1:index+block_size+1]
            probs = np.concatenate((size_prob, interval_prob))

            seq_trainset.append([
                torch.from_numpy(sequence_block).long(),
                torch.from_numpy(target_sequence).long(),
                torch.from_numpy(probs).float(),
                torch.tensor([weight]).float()
            ])
        trainset.extend(seq_trainset)
    size_probs, interval_probs = np.array(size_probs), np.array(interval_probs)
    return trainset, size_probs, interval_probs, seqs

block_size = 16
min_samples = 100
pair_index_file = os.path.join(data_dir, f'{dataset}_pair_index.txt')
trainset, size_probs, interval_probs, seqs = get_trainset(pair_index_file, block_size, n_size, n_interval, min_samples, device)
print(f'Trainset size: {len(trainset)}')

1000it [00:05, 170.07it/s]

Trainset size: 164246





In [67]:
# def get_dataloader(trainset, subset_size, batch_size, seed):
#     np.random.seed(seed)  
#     subset_size = min(subset_size, len(trainset))  
#     ran_index = np.random.choice(len(trainset), subset_size, replace=False)
#     subset_trainset = [trainset[i] for i in ran_index]
#     return DataLoader(subset_trainset, batch_size=batch_size, shuffle=True)

# subset_size = 64000
# batch_size = 256
# seed = 0
# train_loader = get_dataloader(trainset, subset_size, batch_size, seed)
# seq_tensor, target_tensor, prob_tensor, weight_tensor = next(iter(train_loader))
# seq_tensor.shape, target_tensor.shape, prob_tensor.shape, weight_tensor.shape

In [4]:
batch_size = 256
train_loader = DataLoader(trainset, batch_size=batch_size, pin_memory=True)

In [5]:
def get_model_params(model: nn.Module) -> tuple:
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

model = EncoreSequential(n_size, n_interval, 512, block_size, [128, 256, 512], device).to(device)
get_model_params(model)

(37091844, 37091844)

In [6]:
seq_tensor, target_tensor, prob_tensor, weight_tensor = next(iter(train_loader))
seq_tensor, target_tensor, prob_tensor, weight_tensor = seq_tensor.to(device), target_tensor.to(device), prob_tensor.to(device), weight_tensor.to(device)
output = model(seq_tensor, prob_tensor)

In [7]:
model_dir = f'/mnt/ssd1/hsj/encore/{dataset}/model/transformer-01-09/'
model_path = os.path.join(model_dir, f'encore_transformer_200.pt')
if os.path.exists(model_path):
    state_dict = torch.load(model_path)
    # 如果需要移除 `module.` 前缀
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[len('module.'):]] = v
        else:
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    print(f'Loaded model from {model_path}')

Loaded model from /mnt/ssd1/hsj/encore/icbc/model/transformer-01-09/encore_transformer_200.pt


In [19]:
def train_epoch(model, train_loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for seq_tensor, target_tensor, prob_tensor, weight_tensor in train_loader:
        seq_tensor, target_tensor, prob_tensor, weight_tensor = seq_tensor.to(device), target_tensor.to(device), prob_tensor.to(device), weight_tensor.to(device)
        optimizer.zero_grad()
        output = model(seq_tensor, prob_tensor)
        loss = criterion(output, target_tensor, weight_tensor)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * len(seq_tensor)
    return total_loss / len(train_loader.dataset)

In [20]:
criterion = WeightNLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
s_time = time.time()
for epoch in range(1, 11):
    # train_loader = get_dataloader(trainset, subset_size, batch_size, epoch)
    loss = train_epoch(model, train_loader, optimizer, criterion)
    print(f'Epoch {epoch}, Loss: {loss:.4f}, Time: {time.time()-s_time:.2f}')

Epoch 1, Loss: 1.2451, Time: 55.51
Epoch 2, Loss: 1.1301, Time: 112.35
Epoch 3, Loss: 1.0836, Time: 169.37
Epoch 4, Loss: 1.0699, Time: 227.54
Epoch 5, Loss: 1.0333, Time: 287.42
Epoch 6, Loss: 1.0154, Time: 348.34


KeyboardInterrupt: 

In [36]:
def check_seq_gen(seq_gen, size_prob, interval_prob):
    size_gen = seq_gen // n_interval
    interval_gen = seq_gen % n_interval
    return np.min(size_prob[size_gen]) > 0.0 and np.min(interval_prob[interval_gen]) > 0.0


def generate_sequence(model, size_probs, interval_probs, batch_size, block_size, max_length):
    num_smaples = len(size_probs)
    gen_seqs = []
    for i_batch in range(0, num_smaples, batch_size):
        batch_size_probs = size_probs[i_batch:i_batch+batch_size]
        batch_interval_probs = interval_probs[i_batch:i_batch+batch_size]
        batch_probs = np.concatenate((batch_size_probs, batch_interval_probs), axis=1, dtype=np.float32)
        batch_gen_seqs = [[] for _ in range(len(batch_probs))]
        while np.min([len(seq) for seq in batch_gen_seqs]) < max_length:
            current_start_tokens = np.array([start_token] * len(batch_probs))
            seq_gens = model.generate(batch_probs, current_start_tokens, block_size)
            gen_check = [check_seq_gen(seq_gens[i], batch_size_probs[i], batch_interval_probs[i]) for i in range(len(seq_gens))]
            for i in range(len(seq_gens)):
                if gen_check[i]:
                    batch_gen_seqs[i].extend(seq_gens[i])
                    current_start_tokens[i] = seq_gens[i][-1]
            print(i_batch, np.min([len(seq) for seq in batch_gen_seqs]), end='\r')
        batch_gen_seqs = [seq[:max_length] for seq in batch_gen_seqs]
        gen_seqs.extend(batch_gen_seqs)
        print()
        print(len(gen_seqs))
    return gen_seqs

gen_seqs = generate_sequence(model, size_probs, interval_probs, batch_size, block_size, 1000)

0 1008
256
256 1008
512
512 1008
768
768 1008
1000


In [14]:
# def check_seq_gen(seq_gen, size_prob, interval_prob):
#     size_gen = seq_gen // n_interval
#     interval_gen = seq_gen % n_interval
#     return np.min(size_prob[size_gen]) > 0.0 and np.min(interval_prob[interval_gen]) > 0.0


# def generate_seq(model, size_prob, interval_prob, block_size, max_length):
#     model.eval()
#     current_start_token = start_token
#     seq = []
#     prob = np.concatenate((size_prob, interval_prob), dtype=np.float32)
#     with torch.no_grad():
#         while len(seq) < max_length:
#             seq_gen = model.generate(prob, current_start_token, block_size)
#             if check_seq_gen(seq_gen, size_prob, interval_prob):
#                 seq.extend(seq_gen)
#                 current_start_token = seq_gen[-1]
#             # print(f'Generated sequence length: {len(seq)}', end='\r')
#         # print()
#     return seq

In [15]:
# gens = []
# for i in tqdm(range(len(size_probs))):
#     seq = generate_seq(model, size_probs[i], interval_probs[i], block_size, 1000)
#     # print(f'Sequence {i} length: {len(seq)}')
#     gens.append(seq)

  1%|          | 1/100 [00:06<11:23,  6.90s/it]




  1%|          | 1/100 [00:11<19:46, 11.99s/it]


KeyboardInterrupt: 

In [37]:
def test_sequence_gen(seq_ori, seq_gen, n_interval):
    seq_ori, seq_gen = np.array(seq_ori), np.array(seq_gen)
    size_seq, interval_seq = seq_ori // n_interval, seq_ori % n_interval
    size_gen, interval_gen = seq_gen // n_interval, seq_gen % n_interval
    ori_size_interval = list(zip(size_seq, interval_seq))
    gen_size_interval = list(zip(size_gen, interval_gen))
    jsds = {}
    ori_size_interval_dist = Counter(ori_size_interval)
    gen_size_interval_dist  = Counter(gen_size_interval)
    jsds['size_interval'] = compute_js_distance(ori_size_interval_dist, gen_size_interval_dist) ** 2
    for n in [2, 3, 4]:
        n_gram_ori = compute_ngram_distribution(size_seq, n)
        n_gram_gen = compute_ngram_distribution(size_gen, n)
        jsds['size_{}'.format(n)] = compute_js_distance(n_gram_ori, n_gram_gen) ** 2
    return jsds

jsds = []
for i in range(1000):
    jsd = test_sequence_gen(seqs[i], gen_seqs[i], n_interval)
    jsds.append(jsd)
    print(f'Sequence {i} JSD: {jsd}')

Sequence 0 JSD: {'size_interval': 0.005819697701083214, 'size_2': 0.0011287798199279425, 'size_3': 0.0017826977831219394, 'size_4': 0.0024258385536060825}
Sequence 1 JSD: {'size_interval': 0.007076993383931781, 'size_2': 0.0004992656805095407, 'size_3': 0.000649318340534293, 'size_4': 0.0024743483879011126}
Sequence 2 JSD: {'size_interval': 0.003328599991712988, 'size_2': 0.001379992358299539, 'size_3': 0.002318968011033522, 'size_4': 0.00437766585817431}
Sequence 3 JSD: {'size_interval': 0.006911237473951004, 'size_2': 0.0010842493096944346, 'size_3': 0.0021685898784508887, 'size_4': 0.005086693847850188}
Sequence 4 JSD: {'size_interval': 0.009112488439404407, 'size_2': 0.01776009701419041, 'size_3': 0.031482549367689475, 'size_4': 0.045643452738644286}
Sequence 5 JSD: {'size_interval': 0.0010373067681914243, 'size_2': 0.0, 'size_3': 0.0, 'size_4': 0.0}
Sequence 6 JSD: {'size_interval': 0.22329010865722043, 'size_2': 0.2229754452968182, 'size_3': 0.36580452383520856, 'size_4': 0.39601

In [38]:
jsds = pd.DataFrame(jsds)

In [39]:
for key in jsds.keys():
    print(f'{key}: mean: {jsds[key].mean()}, median: {jsds[key].median()}, p90: {jsds[key].quantile(0.9)}, p95: {jsds[key].quantile(0.95)}, p99: {jsds[key].quantile(0.99)}')

size_interval: mean: 0.04343357499528365, median: 0.008694420875225635, p90: 0.14796114656196227, p95: 0.26659155301721077, p99: 0.3826285277537146
size_2: mean: 0.03919982898758229, median: 0.0024793791911609875, p90: 0.13660048975033376, p95: 0.28287794586456505, p99: 0.3908552737087098
size_3: mean: 0.05853032775115767, median: 0.004932467506191716, p90: 0.24744079573731118, p95: 0.3781692646006074, p99: 0.46410718255855976
size_4: mean: 0.07449789790997467, median: 0.008779793814995239, p90: 0.3204629501221032, p95: 0.41636908257236155, p99: 0.48042925994279645


In [35]:
for key in jsds.keys():
    print(f'{key}: mean: {jsds[key].mean()}, median: {jsds[key].median()}, p90: {jsds[key].quantile(0.9)}, p95: {jsds[key].quantile(0.95)}, p99: {jsds[key].quantile(0.99)}')

size_interval: mean: 0.0548317159109112, median: 0.01878444758433427, p90: 0.16482003526940397, p95: 0.268804331562431, p99: 0.38460681733337254
size_2: mean: 0.04462036085837829, median: 0.005882325082381739, p90: 0.16063841680356572, p95: 0.28587929688998953, p99: 0.39784901042032894
size_3: mean: 0.06880187720820077, median: 0.011257682879155996, p90: 0.27308699170430983, p95: 0.3810706475152146, p99: 0.4640546436905572
size_4: mean: 0.08952715708343681, median: 0.018529527728655043, p90: 0.35796939462134747, p95: 0.41938744520103655, p99: 0.4972871062818458


In [29]:
np.argsort(jsds['size_4'].values), np.sort(jsds['size_4'].values)

(array([99, 34, 44, 22, 45, 46, 54, 55, 28, 58, 30,  5,  9, 92, 68, 90, 89,
        66, 57, 23, 93, 74, 67, 36, 94, 52, 56, 15, 38, 18,  0, 85, 17, 98,
        40,  4,  3, 33, 41, 80,  1, 61,  2, 81, 87, 95, 64,  7, 86, 51, 37,
        79, 26, 50, 29, 77, 60, 19, 91, 59, 27, 12, 73, 71, 25, 35, 72, 63,
        65, 78, 76, 32, 16, 11, 48, 70, 62, 88, 24, 49, 97, 69, 10, 13, 83,
        96, 42, 47, 84, 82, 31, 43, 75, 21,  8, 39,  6, 20, 14, 53]),
 array([0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 1.37348465e-05, 3.71760920e-05, 1.77514450e-04,
        3.33676360e-04, 3.79575737e-04, 3.84660777e-04, 4.91131936e-04,
        6.24186371e-04, 8.52671849e-04, 9.01621571e-04, 9.18920523e-04,
        1.06582061e-03, 1.47500209e-03, 1.6255

In [36]:
size_seq = seqs[8] // n_interval
size_gen = np.array(gens[8]) // n_interval
n_gram_ori = compute_ngram_distribution(size_seq, 2)
n_gram_gen = compute_ngram_distribution(size_gen, 2)

In [37]:
size_seq, size_gen[0:len(size_seq)]

(array([18, 16, 10, 10,  9, 11, 10, 11,  7, 11, 12,  9, 11, 12, 14, 19, 16,
        13, 10,  9, 20, 10, 13, 14, 11, 20, 11, 20, 17, 11, 10, 18, 11, 12,
        14, 11, 18, 10, 20, 11, 10, 20, 21, 17, 10, 11, 11, 11, 11, 10, 11,
        11, 11, 17, 10, 15, 11, 11, 15, 10, 10,  9,  9,  9, 18]),
 array([18, 16, 10, 10,  9, 11, 10, 11,  7, 11, 12,  9, 11, 12, 14, 19, 16,
        13, 11, 11, 16, 13, 14, 11, 20, 11, 20, 17, 11, 10, 18, 11, 10, 11,
        11, 11, 17, 10, 15, 11, 11, 15, 10, 10,  9, 17, 11, 15, 19, 11, 17,
        20, 15, 20, 11, 11, 11, 20, 11, 11, 11, 10, 11, 11, 11]))