In [42]:
"""
Sample from a trained model
"""
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from model import GPTConfig, GPT
import json

# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out-stock' # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 10 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------


In [43]:
import pandas as pd

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)


model.eval()
model.to(device)


datadir = os.path.join('data', 'stock')

# meta数据
meta = {}
with open(os.path.join(datadir, 'meta.pkl'), 'r') as f:
    meta = json.load(f)
    meta_vocab_size = meta['vocab_size']
    meta_vocab_size = 4096
def decode(id):
    return meta['itos'][str(id)]
def encode(s):
    return [meta['stoi'][c] for c in s]

pd_train_data = pd.read_csv(os.path.join(datadir, 'train.csv')).iloc[1:,:meta_vocab_size+1]
pd_val_data = pd.read_csv(os.path.join(datadir, 'val.csv')).iloc[:,:meta_vocab_size+1]
pd_val_data



config is GPTConfig(block_size=5, vocab_size=4096, n_layer=8, n_head=8, n_embd=512, dropout=0, bias=False)
number of parameters: 27.27M


Unnamed: 0,trade_date,000001.SZ,000002.SZ,000004.SZ,000005.SZ,000006.SZ,000007.SZ,000008.SZ,000009.SZ,000010.SZ,...,603320.SH,603321.SH,603322.SH,603323.SH,603324.SH,603325.SH,603326.SH,603327.SH,603328.SH,603329.SH
0,20240206,1.032,1.0359,0.9051,0.9506,1.042,0.995,1.0412,1.1,0.961,...,0.9091,0.9156,1.0424,1.0339,1.0262,1.0645,0.9214,1.0366,1.0486,0.9616
1,20240207,0.9928,1.0011,0.9859,0.9481,1.0161,0.9799,0.9901,1.0991,0.9036,...,0.9,0.9136,0.95,1.0211,0.9627,0.9671,0.9197,0.9729,1.0232,1.0999
2,20240208,1.0062,1.0567,1.0559,0.9452,1.0714,1.0308,1.035,0.9942,1.0955,...,1.1,1.1008,1.0954,0.9954,1.0999,1.0887,1.0727,1.0967,1.0923,1.0387
3,20240219,1.0145,0.9781,1.08,0.9565,0.9852,1.0149,1.0145,0.995,1.0359,...,1.0999,1.0534,1.1001,1.0138,1.044,1.0,1.061,1.0639,1.0096,0.9962
4,20240220,1.001,1.0061,1.0144,1.0455,1.0025,1.049,0.9952,1.0,1.0396,...,1.0306,1.0127,1.0618,1.0045,1.0289,1.0154,1.0176,1.0062,1.0047,1.013
5,20240221,1.0998,1.0313,1.0218,1.0435,1.01,1.0023,1.0096,0.9858,1.0429,...,1.0248,1.0233,0.9799,1.0226,1.0052,1.0249,1.0565,1.001,1.0,1.0098
6,20240222,1.0093,0.9941,1.0455,1.0556,0.9975,0.9837,1.0047,1.0102,1.0137,...,1.0502,1.0385,1.0865,1.0133,1.0189,1.0016,1.0193,1.0237,1.022,1.0007
7,20240223,0.9954,1.0049,1.0586,1.0526,1.0099,0.9763,1.0142,0.9992,1.045,...,1.0524,1.0589,1.0221,0.9956,1.0253,1.0271,1.0641,1.0241,1.0231,1.0149
8,20240226,0.9705,0.9784,1.0335,1.05,0.9902,1.0267,1.0047,0.9941,1.0216,...,1.0315,1.0795,1.0,0.9781,1.0107,1.0133,1.0014,1.0186,1.012,1.014
9,20240227,0.9972,1.008,1.0998,1.0476,1.0248,0.9953,1.0093,1.0101,1.0127,...,1.0347,1.0103,1.0341,1.0022,1.046,1.02,1.0274,1.0308,1.0461,1.0355


In [44]:
def trans_frame_to_id(dataframe):
    train_data = dataframe.iloc[:, 1:]
    # 对所有行，都取前10个最大的
    def top_n(row, n):
        # return row.nlargest(n).values
        return row.nlargest(n).index.tolist()

    n = 20
    data_top_10 = train_data.apply(top_n, axis=1, n=n)

    # 将结果转换为 [266, 10] 的形状
    data_transformed = pd.DataFrame(data_top_10.tolist(), index=train_data.index)

    def to_id(row):
        return encode(row)
    
    data_transformed = data_transformed.apply(to_id, axis=1)
    data_transformed = torch.stack([torch.tensor(row) for row in data_transformed])
    return data_transformed

train_data = trans_frame_to_id(pd_train_data)
val_data = trans_frame_to_id(pd_val_data)

val_data

tensor([[2811, 1943, 2157, 2613, 1868, 2093, 1840, 2149, 2411, 1964, 1873, 2590,
         2386, 2230, 2096, 2329, 2536, 1905, 2655, 2069],
        [2262, 2613, 1566, 1700, 2584, 2817, 1909, 2716, 1624, 1715, 1604, 2783,
         2350, 1587, 1844, 2042, 2004, 1632, 1653, 1840],
        [1949, 1939, 1955, 1695, 1762, 2094, 2369, 2452, 2446, 2030, 2079, 2143,
         2409, 2757, 1576, 1827, 1856, 2168, 2437, 2448],
        [1822, 1897, 2186, 1845, 2051, 2085, 2686, 2082, 2307, 2505, 2606, 1566,
         1797, 1872, 1979, 1983, 2025, 2097, 2246, 2407],
        [1972, 1783, 1882, 2051, 1845, 2252, 2573, 2694, 2186, 2223, 2838, 1714,
         2246, 2716, 2728, 1800, 1566, 1746, 1838, 1822],
        [1772, 1652, 1648, 1534, 2523, 2716, 1613, 2246, 2394, 2728, 2554, 2005,
         2245, 1566, 1701, 1728, 2407, 2519, 1908, 2454],
        [1701, 1761, 2182, 2365, 2728, 1681, 1702, 2357, 2671, 2772, 1740, 2309,
         2200, 1850, 1780, 2110, 1946, 1601, 2558, 2621],
        [1577, 1691, 1963, 

In [45]:
torch.manual_seed(333)

block_size = 5

def get_batch(split, i):
    data = train_data if split == 'train' else val_data

    indices = torch.randint(len(data)-1-block_size, (1, ))
    indices = torch.tensor([i])

    # (batch, block)
    x = torch.stack([data[i:i+block_size] for i in indices])
    x = x.gather(2, torch.randint(x.shape[2], (x.shape[0], x.shape[1], 1))).squeeze(-1)

    # (batch, block)
    y = torch.stack([data[i+1:i+1+block_size] for i in indices])
    y = y.gather(2, torch.randint(y.shape[2], (y.shape[0], y.shape[1], 1))).squeeze(-1)


    return x, y

get_batch('val', 2)

(tensor([[1856, 1872, 2694, 2246, 1780]]),
 tensor([[2246, 1838, 2394, 2309, 2085]]))

In [49]:

index = 2
x, y = get_batch('val', index)

print(f'print x')
print(f'x is {x}')
for i in range(block_size):
    print(f'date={pd_val_data.iloc[index+i, 0]}, chg={pd_val_data.loc[index+i, decode(x[0][i].item())]:<6}, code={decode(x[0][i].item())}, code_id={x[0][i].item()}')


idx = model.generate(x, 1)
print('-------')
print(f'predict idx is {idx}')

for i in range(block_size + 1):
    print(f'date={pd_val_data.iloc[index+i, 0]}, chg={pd_val_data.loc[index+i, decode(idx[0][i].item())]:<6}, code={decode(idx[0][i].item())}, code_id={idx[0][i].item()}')





print x
x is tensor([[2446, 2505, 1783, 2394, 1701]])
date=20240208, chg=1.2002, code=300990.SZ, code_id=2446
date=20240219, chg=1.2001, code=301052.SZ, code_id=2505
date=20240220, chg=1.2004, code=300293.SZ, code_id=1783
date=20240221, chg=1.2   , code=300935.SZ, code_id=2394
date=20240222, chg=1.2021, code=300209.SZ, code_id=1701
-------
predict idx is tensor([[2446, 2505, 1783, 2394, 1701, 2404]])
date=20240208, chg=1.2002, code=300990.SZ, code_id=2446
date=20240219, chg=1.2001, code=301052.SZ, code_id=2505
date=20240220, chg=1.2004, code=300293.SZ, code_id=1783
date=20240221, chg=1.2   , code=300935.SZ, code_id=2394
date=20240222, chg=1.2021, code=300209.SZ, code_id=1701
date=20240223, chg=1.0456, code=300946.SZ, code_id=2404
