In [None]:
from odeformerplus.all import *

In [None]:
# presumably seeds the RNGs so runs are repeatable — confirm in odeformerplus (name comes from the star import)
reset_seed(42)
# Device handle reused for both model constructors below; presumably GPU when available — confirm.
DEVICE = get_device()

In [None]:
# Training dataset; this same object feeds both the stage-1 and the stage-2 data loaders below.
ds_train = SymbolicRegressionDataset('dataset/10k/data_all.json')

## stage 1

In [None]:
# Stage-1 loader: shuffled batches of 16, collated by collate_fn_lines.
# train1 reads batch['lines'] and batch['odes'] from each collated batch.
dl_train1 = DataLoader(ds_train, batch_size=16, shuffle=True, collate_fn=collate_fn_lines)

In [None]:
# Stage-1 model, constructed with actn=nn.SiLU() on DEVICE.
stage1 = ODEFormerPlusStage1(actn=nn.SiLU(), device=DEVICE)
# Trailing ';' keeps the notebook from echoing count_params' return value as cell output.
count_params(stage1);

In [None]:
# Adam over all stage-1 parameters with a fixed learning rate of 1e-4 (train1 steps no LR scheduler).
opt1 = torch.optim.Adam(stage1.parameters(), lr=1e-4)

In [None]:
def train1(epochs, last_step=0, print_every=200, save_every=1000, save_path='model_ckpt/stage1/'):
    """Run the stage-1 training loop over ``dl_train1``.

    Uses the notebook globals ``stage1`` (model), ``opt1`` (optimizer) and
    ``dl_train1`` (data loader); each batch must provide ``'lines'`` and
    ``'odes'``.

    Args:
        epochs: Number of full passes over ``dl_train1``.
        last_step: Value the step counter starts from, so that log lines and
            checkpoint file names continue an earlier run's numbering.
        print_every: Print the loss terms every this many steps.
        save_every: Write a checkpoint every this many steps.
        save_path: Prefix placed verbatim in front of ``'<step>.pt'``; end it
            with a path separator to write into a directory.

    A checkpoint for the final step is always written when the loop ends.
    """
    step = last_step
    last_saved = None  # step of the most recent checkpoint written by this call
    for epoch in tqdm(range(epochs), desc='Epoch'):
        for batch in tqdm(dl_train1, total=len(dl_train1), desc='Train', leave=False):
            loss = stage1.train_step(batch['lines'], batch['odes'], opt1, contrastive_loss_weight=1)
            # NOTE: no LR scheduler is in use; if one is added, step it here.
            step += 1

            if step % print_every == 0:
                # Double-quoted f-strings: reusing the outer quote inside the
                # braces is a SyntaxError on Python versions before 3.12.
                print(f"[{epoch}|{step}] loss: {loss['loss']:.2e}, cl: {loss['loss_cl']:.2e}", end=' ')
                print(f"code: {loss['loss_code']:.2e}, ce: {loss['loss_dec']:.2e}")
            if step % save_every == 0:
                save_ckpt(stage1, f'{save_path}{step}.pt')
                last_saved = step

    # Always leave a checkpoint for the final step, but do not rewrite the
    # same file when the periodic save just produced it.
    if last_saved != step:
        save_ckpt(stage1, f'{save_path}{step}.pt')

In [None]:
# 20 epochs of stage-1 training; checkpoints are written as model_ckpt/stage1/316/<step>.pt
train1(20, save_path='model_ckpt/stage1/316/')

## stage 2

In [None]:
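# Optional: uncomment to load a saved stage-1 checkpoint into `stage1`
# instead of re-running the stage-1 training cell above.
# load_ckpt(stage1, 'model_ckpt/stage1/316/12500.pt')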

In [None]:
# Stage-2 loader: batching is delegated to BucketBatchSampler instead of batch_size/shuffle.
# NOTE(review): 16 looks like the batch size (same as stage 1); the meaning of 100 is not visible here — confirm.
# train2 reads batch['lines'], batch['trajectories'] and batch['odes'] from each collated batch.
dl_train2 = DataLoader(ds_train, batch_sampler=BucketBatchSampler(ds_train, 16, 100), collate_fn=collate_fn_trajs_lines)

In [None]:
# Stage-2 model, constructed from the stage-1 model object on DEVICE.
stage2 = ODEFormerPlusStage2(stage1, device=DEVICE)
# Trailing ';' keeps the notebook from echoing count_params' return value as cell output.
count_params(stage2);

In [None]:
# Adam over stage2.parameters() with a fixed learning rate of 1e-4.
# NOTE(review): whether this also updates stage1's weights depends on how ODEFormerPlusStage2 holds it — confirm.
opt2 = torch.optim.Adam(stage2.parameters(), lr=1e-4)

In [None]:
def train2(epochs, last_step=0, print_every=200, save_every=1000, save_path='model_ckpt/stage2/'):
    """Run the stage-2 training loop over ``dl_train2``.

    Uses the notebook globals ``stage2`` (model), ``opt2`` (optimizer) and
    ``dl_train2`` (data loader); each batch must provide ``'lines'``,
    ``'trajectories'`` and ``'odes'``. Every step is run with
    ``noise_sig=0.05`` and ``drop_rate=0.2``.

    Args:
        epochs: Number of full passes over ``dl_train2``.
        last_step: Value the step counter starts from, so that log lines and
            checkpoint file names continue an earlier run's numbering.
        print_every: Print the loss every this many steps.
        save_every: Write a checkpoint every this many steps.
        save_path: Prefix placed verbatim in front of ``'<step>.pt'``; end it
            with a path separator to write into a directory.

    A checkpoint for the final step is always written when the loop ends.
    """
    step = last_step
    last_saved = None  # step of the most recent checkpoint written by this call
    for epoch in tqdm(range(epochs), desc='Epoch'):
        for batch in tqdm(dl_train2, total=len(dl_train2), desc='Train', leave=False):
            loss = stage2.train_step(batch['lines'], batch['trajectories'], batch['odes'], opt2, noise_sig=0.05, drop_rate=0.2)
            # NOTE: no LR scheduler is in use; if one is added, step it here.
            step += 1

            if step % print_every == 0:
                # Double-quoted f-string: reusing the outer quote inside the
                # braces is a SyntaxError on Python versions before 3.12.
                print(f"[{epoch}|{step}] loss: {loss['loss']:.2e}")
            if step % save_every == 0:
                save_ckpt(stage2, f'{save_path}{step}.pt')
                last_saved = step

    # Always leave a checkpoint for the final step, but do not rewrite the
    # same file when the periodic save just produced it.
    if last_saved != step:
        save_ckpt(stage2, f'{save_path}{step}.pt')

In [None]:
# 20 epochs of stage-2 training; checkpoints are written as model_ckpt/stage2/316/<step>.pt
train2(20, save_path='model_ckpt/stage2/316/')