In [2]:
"""
Sample from a trained model
"""
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from model import GPTConfig, GPT
import json

# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out-stock' # ignored if init_from is not 'resume'
start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 10 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------


In [3]:
import pandas as pd

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model
if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)


model.eval()
model.to(device)


datadir = os.path.join('data', 'stock')

# meta数据
meta = {}
with open(os.path.join(datadir, 'meta.pkl'), 'r') as f:
    meta = json.load(f)
    meta_vocab_size = meta['vocab_size']
    meta_vocab_size = 4096
def decode(id):
    return meta['itos'][str(id)]
def decode_arr(ids):
    return [decode(id) for id in ids]
def encode(s):
    return [meta['stoi'][c] for c in s]

pd_train_data = pd.read_csv(os.path.join(datadir, 'train.csv')).iloc[1:,:meta_vocab_size+1]
pd_val_data = pd.read_csv(os.path.join(datadir, 'val.csv')).iloc[:,:meta_vocab_size+1]
pd_train_data



config is GPTConfig(block_size=5, vocab_size=4096, n_layer=8, n_head=8, n_embd=512, dropout=0, bias=False)
number of parameters: 27.27M


Unnamed: 0,trade_date,000001.SZ,000002.SZ,000004.SZ,000005.SZ,000006.SZ,000007.SZ,000008.SZ,000009.SZ,000010.SZ,...,603320.SH,603321.SH,603322.SH,603323.SH,603324.SH,603325.SH,603326.SH,603327.SH,603328.SH,603329.SH
1,20230104,1.0399,1.0461,1.0020,1.0000,1.0065,0.9740,1.0339,0.9943,1.0027,...,1.0483,1.0070,1.0156,1.0235,0.9767,-100.0000,1.0524,0.9876,0.9941,1.0006
2,20230105,1.0112,1.0136,0.9890,1.0214,0.9610,0.9987,0.9959,1.0091,0.9786,...,1.0079,1.0167,1.0188,0.9958,1.0061,-100.0000,0.9858,1.0094,1.0149,0.9851
3,20230106,1.0097,0.9943,0.9756,0.9895,0.9696,0.9949,0.9918,1.0016,1.0601,...,1.0079,0.9945,1.0000,0.9916,1.0052,-100.0000,1.0036,0.9926,0.9985,0.9928
4,20230109,1.0123,0.9771,1.0083,0.9947,0.9930,1.0179,0.9959,0.9927,0.9588,...,1.0211,1.0221,0.9933,1.0085,1.0104,-100.0000,0.9688,1.0050,1.0029,1.0061
5,20230110,0.9757,1.0011,0.9907,0.9894,0.9860,1.0013,0.9917,0.9951,0.9677,...,1.0031,1.0148,0.9921,0.9853,1.0054,-100.0000,0.9901,0.9900,1.0000,0.9813
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
265,20240202,0.9851,1.0053,0.9459,0.9551,1.0051,0.9814,0.9809,0.9896,0.9421,...,0.9106,0.9433,0.9560,0.9744,0.9525,0.9545,0.9303,0.9411,0.9467,1.0168
266,20240205,1.0119,0.9673,0.8996,0.9529,0.9038,0.9501,0.9463,0.9579,0.8991,...,0.9000,0.9003,0.9392,0.9880,0.8998,0.9356,0.9001,0.9001,0.9130,0.9000
267,20240206,1.0320,1.0359,0.9051,0.9506,1.0420,0.9950,1.0412,1.1000,0.9610,...,0.9091,0.9156,1.0424,1.0339,1.0262,1.0645,0.9214,1.0366,1.0486,0.9616
268,20240207,0.9928,1.0011,0.9859,0.9481,1.0161,0.9799,0.9901,1.0991,0.9036,...,0.9000,0.9136,0.9500,1.0211,0.9627,0.9671,0.9197,0.9729,1.0232,1.0999


In [4]:
pd_val_data

Unnamed: 0,trade_date,000001.SZ,000002.SZ,000004.SZ,000005.SZ,000006.SZ,000007.SZ,000008.SZ,000009.SZ,000010.SZ,...,603320.SH,603321.SH,603322.SH,603323.SH,603324.SH,603325.SH,603326.SH,603327.SH,603328.SH,603329.SH
0,20240219,1.0145,0.9781,1.08,0.9565,0.9852,1.0149,1.0145,0.995,1.0359,...,1.0999,1.0534,1.1001,1.0138,1.044,1.0,1.061,1.0639,1.0096,0.9962
1,20240220,1.001,1.0061,1.0144,1.0455,1.0025,1.049,0.9952,1.0,1.0396,...,1.0306,1.0127,1.0618,1.0045,1.0289,1.0154,1.0176,1.0062,1.0047,1.013
2,20240221,1.0998,1.0313,1.0218,1.0435,1.01,1.0023,1.0096,0.9858,1.0429,...,1.0248,1.0233,0.9799,1.0226,1.0052,1.0249,1.0565,1.001,1.0,1.0098
3,20240222,1.0093,0.9941,1.0455,1.0556,0.9975,0.9837,1.0047,1.0102,1.0137,...,1.0502,1.0385,1.0865,1.0133,1.0189,1.0016,1.0193,1.0237,1.022,1.0007
4,20240223,0.9954,1.0049,1.0586,1.0526,1.0099,0.9763,1.0142,0.9992,1.045,...,1.0524,1.0589,1.0221,0.9956,1.0253,1.0271,1.0641,1.0241,1.0231,1.0149
5,20240226,0.9705,0.9784,1.0335,1.05,0.9902,1.0267,1.0047,0.9941,1.0216,...,1.0315,1.0795,1.0,0.9781,1.0107,1.0133,1.0014,1.0186,1.012,1.014
6,20240227,0.9972,1.008,1.0998,1.0476,1.0248,0.9953,1.0093,1.0101,1.0127,...,1.0347,1.0103,1.0341,1.0022,1.046,1.02,1.0274,1.0308,1.0461,1.0355
7,20240228,0.999,0.9871,0.8997,1.0455,0.9661,0.9501,0.9679,0.9757,0.925,...,0.9001,0.9242,1.0039,0.9888,0.9082,0.951,0.9015,0.9738,0.9332,0.9762
8,20240229,1.0095,1.0111,1.0246,1.0543,1.0175,0.995,1.019,1.0206,1.0586,...,1.0155,1.0268,1.0589,1.0136,1.0561,1.0302,1.0355,1.0998,1.0472,1.0308
9,20240301,0.9906,0.9851,1.0232,0.9485,1.0222,0.9874,1.0047,1.0101,0.983,...,1.0036,0.9923,1.0075,1.0,1.0191,1.04,1.0057,1.1003,1.0131,0.9868


In [5]:
def trans_frame_to_id(dataframe):
    train_data = dataframe.iloc[:, 1:]
    # 对所有行，都取前10个最大的
    def top_n(row, n):
        # return row.nlargest(n).values
        return row.nlargest(n).index.tolist()

    n = 20
    data_top_10 = train_data.apply(top_n, axis=1, n=n)

    # 将结果转换为 [266, 10] 的形状
    data_transformed = pd.DataFrame(data_top_10.tolist(), index=train_data.index)

    def to_id(row):
        return encode(row)
    
    data_transformed = data_transformed.apply(to_id, axis=1)
    data_transformed = torch.stack([torch.tensor(row) for row in data_transformed])
    return data_transformed

train_data = trans_frame_to_id(pd_train_data)
val_data = trans_frame_to_id(pd_val_data)


In [30]:
import easyquotation
quotation = easyquotation.use('tencent') # 新浪 ['sina'] 腾讯 ['tencent', 'qq'] 

# quotation.market_snapshot(prefix=True) # prefix 参数指定返回的行情字典中的股票代码 key 是否带 sz/sh 前缀


quotation.stocks(['300890',  '300438', '300147']) 




{'300890': {'name': '翔丰华',
  'code': '300890',
  'now': 35.39,
  'close': 32.64,
  'open': 33.4,
  'volume': 15087900.0,
  'bid_volume': 8221400,
  'ask_volume': 6866500.0,
  'bid1': 0.0,
  'bid1_volume': 0,
  'bid2': 0.0,
  'bid2_volume': 0,
  'bid3': 0.0,
  'bid3_volume': 0,
  'bid4': 0.0,
  'bid4_volume': 0,
  'bid5': 0.0,
  'bid5_volume': 0,
  'ask1': 0.0,
  'ask1_volume': 0,
  'ask2': 0.0,
  'ask2_volume': 0,
  'ask3': 0.0,
  'ask3_volume': 0,
  'ask4': 0.0,
  'ask4_volume': 0,
  'ask5': 0.0,
  'ask5_volume': 0,
  '最近逐笔成交': '',
  'datetime': datetime.datetime(2024, 4, 1, 15, 0, 31),
  '涨跌': 2.75,
  '涨跌(%)': 8.43,
  'high': 35.76,
  'low': 32.64,
  '价格/成交量(手)/成交额': '35.39/150879/526367491',
  '成交量(手)': 15087900,
  '成交额(万)': 526370000.0,
  'turnover': 15.5,
  'PE': 29.42,
  'unknown': '',
  'high_2': 35.76,
  'low_2': 32.64,
  '振幅': 9.56,
  '流通市值': 34.46,
  '总市值': 38.69,
  'PB': 2.25,
  '涨停价': 39.17,
  '跌停价': 26.11,
  '量比': 1.13,
  '委差': 0.0,
  '均价': 34.89,
  '市盈(动)': 32.31,
  '市盈(静

In [6]:
torch.manual_seed(333)

block_size = 5

def get_batch(split, i):
    data = train_data if split == 'train' else val_data

    indices = torch.randint(len(data)-1-block_size, (1, ))
    indices = torch.tensor([i])

    # (batch, block)
    x = torch.stack([data[i:i+block_size] for i in indices])
    x = x.gather(2, torch.randint(x.shape[2], (x.shape[0], x.shape[1], 1))).squeeze(-1)

    # (batch, block)
    y = torch.stack([data[i+1:i+1+block_size] for i in indices])
    y = y.gather(2, torch.randint(y.shape[2], (y.shape[0], y.shape[1], 1))).squeeze(-1)


    return x, y

get_batch('val', 2)

(tensor([[2407, 1850, 2496, 2745, 2654]]),
 tensor([[2558, 1781, 1963, 2583, 2554]]))

In [24]:
from operator import itemgetter


index = 25
haha = {}

def predect():
    data_type = 'val'

    pd_data = pd_train_data if data_type == 'train' else pd_val_data
    x, y = get_batch(data_type, index)

    # print(f'x is {x}')
    idx = model.generate(x, 1)
    # print('-')
    print(f'predict idx is {idx}')
    id = idx[-1][-1].item()
    # print(id)
    haha[str(id)] = haha.get(str(id), 0) + 1

    # print(f'code is {decode_arr(idx.tolist()[0])}')

    # for i in range(block_size + 1):
    #     print(f'date={pd_data.iloc[index+i, 0]}, chg={pd_data.loc[index+i, decode(idx[0][i].item())]:<6}, code={decode(idx[0][i].item())}, code_id={idx[0][i].item()}')

for i in range(100):
    # print(f'{i}-----')
    predect()
print(haha)
sorted_items = sorted(haha.items(), key=itemgetter(1), reverse=True)
# 输出前5个最大值及其键
print("前5个最大的值及其键：")
for key, value in sorted_items[:5]:
    print(f"Key: {key}, Value: {value}")
    print(f'id={key}, code={decode(key)}')
    # print(f'date={pd_val_data.iloc[index+block_size, 0]}, chg={pd_val_data.loc[index+block_size, decode(key)]:<6}, code={decode(key)}, code_id={key}')






predict idx is tensor([[1305, 1667, 2354, 2445, 3076, 2180]])
predict idx is tensor([[1213, 1582,  927, 2189, 1809,  927]])
predict idx is tensor([[1897, 1827, 1343, 1782, 1908, 1889]])
predict idx is tensor([[2444, 2444, 1120, 2422, 2815, 2352]])
predict idx is tensor([[3912, 1915, 1752, 2162, 2223, 1946]])
predict idx is tensor([[3282, 2809, 2354, 2351, 1462, 2077]])
predict idx is tensor([[4040, 1827, 1322, 2458, 2445, 2162]])
predict idx is tensor([[1897, 2352, 2401, 2458, 1086, 1939]])
predict idx is tensor([[1379, 2656, 1115, 1889, 1908, 2574]])
predict idx is tensor([[1897, 1667,   55, 2431, 1908,  153]])
predict idx is tensor([[1305, 2050, 1170, 2540, 2223, 2271]])
predict idx is tensor([[2444, 2419, 1758, 2445,  349, 2484]])
predict idx is tensor([[1997, 2444,  505, 2242, 1809, 1997]])
predict idx is tensor([[ 720, 2656, 2276, 2037, 3076, 2197]])
predict idx is tensor([[4040, 2750, 1253, 1603, 2477, 2071]])
predict idx is tensor([[1897, 2050,  927, 1774, 2281, 2120]])
predict 