In [1]:
import os, sys, json, random, pytz
os.chdir("../scripts/s4")

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from tqdm.auto import tqdm, trange

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from model import S4Model, LSTMModel  # Can use full version instead of minimal S4D standalone below
from config import *
sys.path.append("../")
from causal_transformer.config_taskspecific import *
from causal_transformer.utils import trim_task
from causal_transformer.dataset import sequences_collator
from functools import partial
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets



In [3]:
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
task = "counting_samesymbol_plain3_addbigram_nullseq"
# Look the config class up by name instead of calling eval(): the training
# cell further down defines a function named `eval`, which shadows the builtin
# and would break an eval()-based lookup when this cell is re-run.
config = globals()[f"{task}_Config"]()

In [4]:
# Resolve the model class by name (e.g. config.model == "LSTM" -> LSTMModel)
# without eval(), which a later cell shadows with its own `eval` function.
model = globals()[f"{config.model}Model"](config)
model = model.to(device)
if device == 'cuda':
    cudnn.benchmark = True  # let cuDNN autotune its kernel choice



In [7]:
# Checkpoint whose weights are loaded into `model`. Override it with the
# CKPT_PATH environment variable instead of editing this machine-specific default.
CKPT_PATH = os.environ.get(
    "CKPT_PATH",
    "/data/yingshac/llms_do_math/scripts/s4/output/0506_152821/ckpts/3_25000_LSTM.pt",
)
model.load_state_dict(torch.load(CKPT_PATH, map_location=device), strict=True)

<All keys matched successfully>

In [8]:
# Inspect the encoder's trainable weight tensor (the output shows requires_grad=True).
model.encoder.weight_train

Parameter containing:
tensor([[-1.2166e-03,  4.7733e-04,  1.0582e-03,  ...,  3.2001e-04,
         -1.0826e-03, -1.2724e-03],
        [ 3.5081e-01, -3.0018e-02,  7.6148e-01,  ...,  1.0225e+00,
          8.0272e-01,  4.3185e-01],
        [-2.1265e-03, -6.3462e-04,  3.1366e-03,  ...,  1.8967e-03,
          1.7241e-03, -2.5580e-04],
        ...,
        [-9.2113e-04,  4.9682e-04,  1.3838e-03,  ...,  1.0064e-03,
          8.7424e-04,  1.6728e-03],
        [ 7.2630e-05,  1.6659e-03,  2.2529e-03,  ...,  9.2221e-04,
         -2.3593e-03, -2.6002e-03],
        [ 7.2826e-03, -9.4100e-03, -1.0196e-02,  ..., -1.7992e-02,
          1.7425e-02,  1.8791e-03]], device='cuda:0', requires_grad=True)

In [9]:
# Inspect the encoder's frozen weight tensor (the output shows all zeros and no requires_grad).
model.encoder.weight_freeze

Parameter containing:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')

In [6]:
# Training mixes the main task with the auxiliary tasks; validation uses only the main task.
tasks = [trim_task(task)] + config.aux_tasks

train_data = concatenate_datasets([
    load_dataset(
        "text",
        data_files={"train": f"{config.train_data_path}/{task_name}/train.txt"},
    )['train']
    for task_name in tasks
])
val_data = load_dataset(
    "text",
    data_files={"validation": f"{config.eval_data_path}/{trim_task(task)}/val.txt"},
)['validation']

# TODO: the max_seen_len computation from the script version (args.max_seen_len)
# has not been ported to this notebook.

# w2i maps each token to its index in config.vocab.
collator = partial(
    sequences_collator,
    w2i={word: idx for idx, word in enumerate(config.vocab)},
    max_seq_len=config.max_seq_len,
    max_position_embeddings=config.max_seq_len,
    augmentation=None,
)

train_dataloader = DataLoader(
    train_data,
    batch_size=config.per_device_train_batch_size,
    shuffle=True,
    collate_fn=collator,
)
val_dataloader = DataLoader(
    val_data,
    batch_size=config.per_device_eval_batch_size,
    shuffle=False,
    collate_fn=collator,
)

In [7]:
def setup_optimizer(model, lr, weight_decay, epochs):
    """
    S4 requires a specific optimizer setup.

    The S4 layer (A, B, C, dt) parameters typically
    require a smaller learning rate (typically 0.001), with no weight decay.

    The rest of the model can be trained with a higher learning rate (e.g. 0.004, 0.01)
    and weight decay (if desired).

    Args:
        model: module whose parameters are optimized. Parameters carrying an
            `_optim` dict attribute get their own param group with those
            hyperparameters; all others share the default group.
        lr: learning rate for the default group.
        weight_decay: weight decay for the default group.
        epochs: T_max of the cosine-annealing scheduler, i.e. the number of
            scheduler.step() calls over which the LR is annealed.

    Returns:
        (optimizer, scheduler): an AdamW optimizer and a CosineAnnealingLR on it.
    """

    # All parameters in the model
    all_parameters = list(model.parameters())

    # General parameters don't contain the special _optim key
    params = [p for p in all_parameters if not hasattr(p, "_optim")]

    # Create an optimizer with the general parameters
    optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay)

    # Distinct special-hyperparameter dicts, de-duplicated via hashable frozensets.
    unique_hps = dict.fromkeys(
        frozenset(p._optim.items()) for p in all_parameters if hasattr(p, "_optim")
    )
    # Sets are only partially ordered (by subset), so sorted() over raw
    # frozensets yields an undefined order; sort on a string key for a
    # deterministic param-group order instead.
    hps = [
        dict(s)
        for s in sorted(
            unique_hps, key=lambda s: sorted((str(k), repr(v)) for k, v in s)
        )
    ]
    for hp in hps:
        group_params = [p for p in all_parameters if getattr(p, "_optim", None) == hp]
        optimizer.add_param_group({"params": group_params, **hp})

    # Create a lr scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Print optimizer info
    keys = sorted({k for hp in hps for k in hp})
    for i, g in enumerate(optimizer.param_groups):
        group_hps = {k: g.get(k, None) for k in keys}
        print(' | '.join([
            f"Optimizer group {i}",
            f"{len(g['params'])} tensors",
        ] + [f"{k} {v}" for k, v in group_hps.items()]))

    return optimizer, scheduler


In [8]:
# NOTE(review): labels equal to -1 are excluded from the loss; presumably this
# is the padding label produced by sequences_collator. Confirm.
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer, scheduler = setup_optimizer(
    model,
    epochs=config.num_epochs,
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)

Optimizer group 0 | 22 tensors | lr 0.01 | weight_decay 0.01
Optimizer group 1 | 4 tensors | lr 0.001 | weight_decay 0.01
Optimizer group 2 | 20 tensors | lr 0.001 | weight_decay 0.0


In [10]:
best_acc = 0  # best accuracy (percent) seen by eval(..., checkpoint=True) so far
start_epoch = 0  # start from epoch 0 or last checkpoint epoch


def _run_batch(batch):
    """Forward one batch through `model`; return (loss, n_correct, n_counted).

    Positions whose label equals ``criterion.ignore_index`` are left out of the
    accuracy counts, just as ``criterion`` leaves them out of the loss. Counting
    them inflates `total`, and with ignore_index=-1 they can never match a
    prediction, which deflates the reported accuracy.
    """
    inputs = batch['input_id'].to(device)
    # reshape (not view) also works if the tensor is non-contiguous
    targets = batch['label'].to(device).reshape(-1)
    logits = model(inputs)  # B, seq_len, vocab_size
    logits = logits.reshape(-1, logits.size(-1))  # bs*seq_len, vocab_size
    loss = criterion(logits, targets)
    predicted = logits.argmax(dim=1)
    counted = targets.ne(criterion.ignore_index)
    n_correct = (predicted.eq(targets) & counted).sum().item()
    n_counted = counted.sum().item()
    return loss, n_correct, n_counted


def _show_progress(pbar, batch_idx, n_batches, loss_sum, correct, total):
    """Write running loss/accuracy to a tqdm bar; accuracy is 0 while total == 0."""
    acc = 100. * correct / total if total else 0.0
    pbar.set_description(
        'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
        (batch_idx, n_batches, loss_sum / (batch_idx + 1), acc, correct, total)
    )


def train():
    """Run one optimization pass over `train_dataloader`, updating `model` in place."""
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    for batch_idx, batch in pbar:
        optimizer.zero_grad()
        loss, n_correct, n_counted = _run_batch(batch)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        correct += n_correct
        total += n_counted
        _show_progress(pbar, batch_idx, len(train_dataloader), train_loss, correct, total)


def eval(epoch, dataloader, checkpoint=False):
    """Evaluate `model` on `dataloader` and return its accuracy in percent.

    NOTE: this name shadows the builtin ``eval`` for the rest of the kernel
    session. Any cell that calls the builtin ``eval`` on a string will break
    if run after this one.

    Args:
        epoch: epoch number recorded in the saved checkpoint.
        dataloader: iterable of batches with 'input_id' and 'label' entries.
        checkpoint: if True, save ./checkpoint/ckpt.pth whenever the accuracy
            beats the global `best_acc`, and update `best_acc`.

    Returns:
        Accuracy over non-ignored label positions, in percent (0.0 if none).
        Returned whether or not `checkpoint` is set.
    """
    global best_acc
    model.eval()
    eval_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for batch_idx, batch in pbar:
            loss, n_correct, n_counted = _run_batch(batch)
            eval_loss += loss.item()
            correct += n_correct
            total += n_counted
            _show_progress(pbar, batch_idx, len(dataloader), eval_loss, correct, total)

    acc = 100. * correct / total if total else 0.0

    # Save checkpoint.
    if checkpoint and acc > best_acc:
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

    return acc


In [None]:
# One training pass and one validation pass per epoch. The cosine scheduler is
# stepped once per epoch, matching its T_max of config.num_epochs.
# (The original `for epoch in start_epoch, config.num_epochs:` iterated over a
# 2-tuple, running only epochs 0 and num_epochs.)
for epoch in range(start_epoch, config.num_epochs):
    print(f"start epoch {epoch}")
    train()
    val_acc = eval(epoch, val_dataloader, checkpoint=True)
    print(f"epoch {epoch} | val acc {val_acc:.3f}")
    scheduler.step()


## Draft (scratch cells: tqdm experiments, not part of the training run)

In [4]:
from tqdm import tqdm
import sys

In [5]:
# Scratch check: create a plain tqdm bar over range(10), written to stderr.
# NOTE(review): `tqdm` here is the plain `tqdm.tqdm` from the import cell just
# above. That import rebinds the `tqdm.auto` version imported at the top of the
# notebook, so train()/eval() use the plain one if they run after that cell.
bar = tqdm(range(10), file=sys.stderr)


  0%|          | 0/10 [01:17<?, ?it/s]


In [6]:
# Advance the scratch bar by 10 steps, the full length of its range(10).
bar.update(10)



True

In [None]:
# NOTE(review): `TQDMProgressBar` is not imported anywhere in this notebook (it
# appears to be the PyTorch Lightning callback). Import it, or delete this draft
# cell, before running it.
class LitProgressBar(TQDMProgressBar):
    """Progress bar whose validation bar is labelled 'running validation ...'."""

    def init_validation_tqdm(self):
        """Create the parent class's validation bar and set its description."""
        bar = super().init_validation_tqdm()
        bar.set_description('running validation ...')
        return bar