## Shared

In [None]:
import json
import math
import os
import random
from typing import Dict, List, Tuple
import numpy as np
import sacrebleu
import torch
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import io
import zstandard as zstd
import sys


notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

%load_ext autoreload
%autoreload 2


from model.model import LinguSign


def set_seed(seed: int):
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def load_npz_zst(path: str) -> dict:
    """Read a zstd-compressed ``.npz`` archive into a plain dict.

    The file is decompressed fully into memory, then parsed with
    ``np.load``; every array in the archive is materialised before the
    archive is closed.

    Args:
        path: Location of the compressed archive on disk.

    Returns:
        Mapping from each array name in the archive to its array.
    """
    with open(path, 'rb') as raw_file, zstd.ZstdDecompressor().stream_reader(
        raw_file
    ) as stream:
        payload = stream.read()

    with np.load(io.BytesIO(payload)) as archive:
        # Index every entry while the archive is still open.
        return {name: archive[name] for name in archive.files}


class SLTDataset(Dataset):
    """Map-style dataset with one ``.npz.zst`` sample file per item.

    Each sample file is loaded with ``load_npz_zst`` and must contain an
    ``'X'`` array of shape ``(T, 1024, 1, 7, 7)`` and a ``'Y'`` array.
    When ``include_text`` is set, the text is read from a sidecar
    ``.txt`` file that sits next to the sample.
    """

    # File suffix of a sample; the sidecar text file swaps it for '.txt'.
    _SAMPLE_SUFFIX = '.npz.zst'

    def __init__(self, paths: List[str], include_text=False):
        """Store the sample paths.

        Args:
            paths: Paths of the ``.npz.zst`` sample files (copied into a list).
            include_text: If true, ``__getitem__`` also reads the sidecar
                ``.txt`` file into the ``'text'`` entry.
        """
        self.paths = list(paths)
        self.include_text = include_text

    def __len__(self):
        """Return the number of samples."""
        return len(self.paths)

    @classmethod
    def _text_path(cls, path: str) -> str:
        """Return the sidecar ``.txt`` path for a sample path.

        Only the trailing ``.npz.zst`` is swapped, so a directory whose name
        happens to contain ``.npz.zst`` is left intact. Paths that do not
        end with the suffix keep the previous substring-replace behaviour.
        """
        suffix = cls._SAMPLE_SUFFIX
        if path.endswith(suffix):
            return path[: -len(suffix)] + '.txt'
        return path.replace(suffix, '.txt')

    def __getitem__(self, i: int) -> Dict:
        """Load sample ``i``.

        Returns:
            The dict loaded from the file, with ``'X'`` converted to a float
            tensor, ``'Y'`` converted to a tensor, and ``'text'`` set to the
            sidecar file content (or ``''`` when ``include_text`` is false).

        Raises:
            AssertionError: If ``'X'`` is not shaped ``(T, 1024, 1, 7, 7)``.
        """
        path = self.paths[i]
        xy = load_npz_zst(path)

        assert (
            xy['X'].ndim == 5
            and xy['X'].shape[1] == 1024
            and xy['X'].shape[2] == 1
            and xy['X'].shape[3] == 7
            and xy['X'].shape[4] == 7
        ), f'Expected (T,1024,1,7,7), got {tuple(xy["X"].shape)} @ {path}'

        xy['X'] = torch.from_numpy(xy['X']).float()
        xy['Y'] = torch.from_numpy(xy['Y'])
        xy['text'] = ''

        if self.include_text:
            with open(self._text_path(path), encoding='utf-8') as f:
                xy['text'] = f.read()

        return xy


def collate_batch(
    batch: List[Dict],
    max_vis_tokens: int,
    max_text_tokens: int,
    pad_token_id: int,
    pad_to_multiple_of: int = 8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[str]]:
    """Pad and truncate a list of samples into batch tensors.

    Args:
        batch: Samples with ``'X'`` (frames first, then ``1024, 1, 7, 7``),
            ``'Y'`` (1-D token ids) and ``'text'`` entries.
        max_vis_tokens: Upper bound on the frame dimension; longer clips
            are truncated.
        max_text_tokens: Upper bound on the token dimension; longer
            sequences are truncated.
        pad_token_id: Fill value for unused token positions.
        pad_to_multiple_of: If truthy, the frame dimension is rounded up to
            this multiple (still capped at ``max_vis_tokens``).

    Returns:
        ``(vis_batch, vis_lengths, text_ids, text_lengths, texts)`` where
        ``vis_batch`` is zero-padded float32, ``text_ids`` is padded with
        ``pad_token_id``, the two length tensors hold the truncated
        per-sample lengths, and ``texts`` lists each sample's ``'text'``.
    """
    n_items = len(batch)

    # Per-sample lengths after truncation.
    frame_counts = [min(sample['X'].shape[0], max_vis_tokens) for sample in batch]
    token_counts = [min(sample['Y'].shape[0], max_text_tokens) for sample in batch]

    frame_dim = min(max(frame_counts), max_vis_tokens)
    if pad_to_multiple_of:
        rounded_up = math.ceil(frame_dim / pad_to_multiple_of) * pad_to_multiple_of
        frame_dim = min(rounded_up, max_vis_tokens)
    token_dim = max(token_counts)

    vis_batch = torch.zeros(n_items, frame_dim, 1024, 1, 7, 7, dtype=torch.float32)
    text_ids = torch.full((n_items, token_dim), pad_token_id, dtype=torch.long)

    texts = []
    for row, sample in enumerate(batch):
        frames, tokens = sample['X'], sample['Y']

        keep_frames = min(frames.shape[0], frame_dim)
        vis_batch[row, :keep_frames] = frames[:keep_frames]

        keep_tokens = min(tokens.shape[0], token_dim, max_text_tokens)
        text_ids[row, :keep_tokens] = tokens[:keep_tokens]

        texts.append(sample['text'])

    vis_lengths = torch.tensor(frame_counts, dtype=torch.long)
    text_lengths = torch.tensor(token_counts, dtype=torch.long)
    return vis_batch, vis_lengths, text_ids, text_lengths, texts


def split_paths(paths: List[str], seed: int, ratios=(0.90, 0.05, 0.05)):
    """Shuffle ``paths`` deterministically and cut them into three splits.

    Args:
        paths: Items to split; the input is copied, not modified.
        seed: Seed for the private ``random.Random`` used to shuffle.
        ratios: ``(train, val, test)`` fractions, which must sum to 1.
            The train and val sizes are rounded; test takes the rest.

    Returns:
        ``(train, val, test)`` lists.
    """
    assert abs(sum(ratios) - 1.0) < 1e-6

    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)

    total = len(shuffled)
    train_end = int(round(ratios[0] * total))
    val_end = train_end + int(round(ratios[1] * total))

    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]


def corpus_bleu(refs: List[str], hyps: List[str]) -> float:
    """Score hypotheses against a single reference set with sacreBLEU.

    Args:
        refs: One reference string per hypothesis.
        hyps: System outputs, aligned with ``refs``.

    Returns:
        The ``score`` of ``sacrebleu.corpus_bleu`` using ``'exp'`` smoothing.
    """
    # sacreBLEU takes a list of reference streams; there is exactly one here.
    return sacrebleu.corpus_bleu(hyps, [refs], smooth_method='exp').score


def get_paths() -> List[str]:
    """Collect every ``.npz.zst`` file under the parent directory.

    The tree rooted at ``'..'`` (relative to the current working
    directory) is walked recursively.

    Returns:
        Paths of all matching files, sorted lexicographically. Sorting
        matters because ``os.walk`` yields entries in filesystem order; a
        seeded shuffle of an unsorted list would produce a different
        train/val/test split on another machine or after files move.
    """
    file_paths = []

    for root, _, files in os.walk('..'):
        file_paths.extend(
            os.path.join(root, name) for name in files if name.endswith('.npz.zst')
        )

    file_paths.sort()
    return file_paths


def run_eval(
    model: LinguSign,
    loader: DataLoader,
    max_new_tokens: int,
    generate: bool,
) -> Dict:
    """Evaluate ``model`` on ``loader`` and report BLEU and mean loss.

    The model is switched to eval mode and every batch is run under
    ``torch.no_grad()``. The caller is responsible for switching back to
    train mode afterwards.

    Args:
        model: Model exposing ``generate(...)`` and a forward call that
            returns a mapping with ``'ce_loss'`` or ``'loss'``.
        loader: Yields ``(vis, vis_lengths, text_ids, text_lengths, texts)``.
        max_new_tokens: Generation length limit forwarded to
            ``model.generate`` (previously ignored in favour of a
            hard-coded 96).
        generate: If true, decode greedily (``do_sample=False``,
            ``num_beams=1``) and compute corpus BLEU against ``texts``.

    Returns:
        ``{'bleu4': float, 'ce': float}``. ``bleu4`` is 0.0 when
        ``generate`` is false or nothing was generated; ``ce`` is the
        unweighted mean of the per-batch losses, or 0.0 if there were none.
    """
    model.eval()
    all_refs, all_hyps = [], []
    ce_losses = []
    with torch.no_grad():
        for vis, vis_lengths, text_ids, text_lengths, texts in tqdm(
            loader, desc='Eval', leave=False
        ):
            if generate:
                hyps = model.generate(
                    vis_tokens=vis,
                    vis_lengths=vis_lengths,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    num_beams=1,
                )
                all_hyps.extend(hyps)
                all_refs.extend(texts)

            out = model(
                vis_tokens=vis,
                vis_lengths=vis_lengths,
                text_ids=text_ids,
                text_lengths=text_lengths,
                in_warmup=False,
            )
            # Prefer the pure cross-entropy term when the model reports it.
            if 'ce_loss' in out:
                ce_losses.append(float(out['ce_loss'].item()))
            elif 'loss' in out:
                ce_losses.append(float(out['loss'].item()))

    bleu = 0.0
    # Skip scoring on an empty loader: there is nothing to compare.
    if generate and all_hyps:
        bleu = corpus_bleu(all_refs, all_hyps)
    ce = float(np.mean(ce_losses)) if ce_losses else 0.0
    return {'bleu4': bleu, 'ce': ce}


def _mk_loader(ds, shuffle: bool, batch_size: int):
    """Build a ``DataLoader`` over ``ds`` using the notebook-wide settings.

    Reads the module-level ``config`` and ``model``: the collate limits come
    from ``config``, the pad token id from ``model.tokenizer`` and pinned
    memory is enabled only when ``model.device`` is a CUDA device.

    Args:
        ds: Dataset to wrap.
        shuffle: Whether the loader reshuffles every epoch.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` with 4 workers that keeps the last partial batch.
    """

    def _collate(samples):
        # Globals are looked up per call, exactly like a lambda would.
        return collate_batch(
            samples,
            max_vis_tokens=config['max_vis_tokens'],
            max_text_tokens=config['max_text_tokens'],
            pad_token_id=model.tokenizer.pad_token_id,
            pad_to_multiple_of=8,
        )

    use_pinned_memory = model.device.type == 'cuda'
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
        pin_memory=use_pinned_memory,
        collate_fn=_collate,
        drop_last=False,
    )


# Notebook-wide setup: seed the RNGs, load the config, discover and split the
# sample files, then build the datasets and the model used by later cells.
seed = 42

# Seed before anything random happens (data split, model construction).
set_seed(seed)

# Hyper-parameters are read from 'config.json' in the working directory.
with open('config.json') as f:
    config = json.load(f)

all_paths = get_paths()
assert len(all_paths) > 0, 'No paths found.'

# 90/5/5 split, shuffled with the same seed as above.
train_paths, val_paths, test_paths = split_paths(
    all_paths, seed=seed, ratios=(0.90, 0.05, 0.05)
)
print(
    f'Dataset sizes -> train: {len(train_paths)}  val: {len(val_paths)}  test: {len(test_paths)}'
)

# Only val/test load the sidecar text; training samples get text == ''.
dset_train = SLTDataset(train_paths)
dset_val = SLTDataset(val_paths, include_text=True)
dset_test = SLTDataset(test_paths, include_text=True)

model = LinguSign()
# print(list(model.mt5.named_parameters()))

BATCH_SIZE = config['batch_size']

## Train


In [None]:
def _unpack_losses(out):
    """Return ``(loss, ce, vt)`` from a model output.

    ``out`` is either a mapping with ``'loss'`` and optional ``'ce_loss'`` /
    ``'vt_loss'`` entries, or the loss tensor itself. ``ce`` and ``vt`` are
    plain floats and default to 0.0 when the model does not report them.
    """
    if not isinstance(out, dict):
        return out, 0.0, 0.0
    ce = float(out['ce_loss'].item()) if 'ce_loss' in out else 0.0
    vt = float(out['vt_loss'].item()) if 'vt_loss' in out else 0.0
    return out['loss'], ce, vt


def _running_means(running, count):
    """Return the mean ``(loss, ce, vt)`` over ``count`` micro-batches."""
    denom = max(1, count)
    return running['loss'] / denom, running['ce'] / denom, running['vt'] / denom


def _wandb_log(wandb_run, payload, step):
    """Log ``payload`` at ``step`` if a W&B run was supplied."""
    if wandb_run is not None:
        wandb_run.log(payload, step=step)


def _save_checkpoint(
    model, optimizer, scheduler, epoch, global_step, val_res, out_dir, tag
):
    """Write ``<tag>.pt`` and ``peft_<tag>/`` under ``out_dir``; return the .pt path."""
    ckpt_path = os.path.join(out_dir, f'{tag}.pt')
    torch.save(
        {
            'model_state': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch,
            'global_step': global_step,
            'val': val_res,
        },
        ckpt_path,
    )
    peft_dir = os.path.join(out_dir, f'peft_{tag}')
    os.makedirs(peft_dir, exist_ok=True)
    model.mt5.save_pretrained(peft_dir)
    return ckpt_path


def train_loop(
    model: LinguSign,
    optimizer,
    scheduler,
    vt_warmup_loader: DataLoader,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    max_new_tokens: int,
    epochs: int,
    grad_accum_steps: int,
    vt_warmup_steps: int,
    vt_warmup_grad_accum_steps: int,
    val_every_n_epochs: int,
    out_dir: str,
    wandb_run=None,
):
    """Run the optional warmup phase, the epoch loop, then a final test eval.

    Phases:
      1. Warmup: ``vt_warmup_steps`` optimizer steps over
         ``vt_warmup_loader`` with the model called with ``in_warmup=True``
         (what the flag changes is defined by the model).
      2. Training: ``epochs`` passes over ``train_loader`` with gradient
         accumulation; a trailing partial accumulation window still
         triggers an optimizer step.
      3. Every ``val_every_n_epochs`` epochs: evaluate on ``val_loader``,
         write ``last.pt`` / ``peft_last``, and also ``best.pt`` /
         ``peft_best`` when validation BLEU improves.
      4. After the last epoch: evaluate on ``test_loader`` with the model
         as it is at that point (not the best checkpoint).

    Args:
        model: Model to train; must expose ``mt5.save_pretrained``.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: LR scheduler stepped together with the optimizer.
        vt_warmup_loader: Batches for the warmup phase.
        train_loader: Batches for the main phase.
        val_loader: Batches for periodic validation.
        test_loader: Batches for the final test evaluation.
        max_new_tokens: Generation limit forwarded to ``run_eval``.
        epochs: Number of training epochs.
        grad_accum_steps: Micro-batches per optimizer step in training.
        vt_warmup_steps: Optimizer steps in the warmup phase (0 disables it).
        vt_warmup_grad_accum_steps: Micro-batches per optimizer step in warmup.
        val_every_n_epochs: Validation/checkpoint period in epochs.
        out_dir: Directory for checkpoints (created if missing).
        wandb_run: Optional W&B run for metric logging.

    Raises:
        ValueError: If warmup is requested but ``vt_warmup_loader`` yields
            no batches (the loop could otherwise never finish).
    """
    os.makedirs(out_dir, exist_ok=True)

    best_bleu4 = 0
    global_step = 0

    if vt_warmup_steps > 0:
        print(
            f'Starting VT warmup for {vt_warmup_steps} steps '
            f'(batch_size={vt_warmup_loader.batch_size})'
        )
        model.train()
        optimizer.zero_grad(set_to_none=True)

        vt_warmup_steps_done = 0
        running = {'loss': 0.0, 'ce': 0.0, 'vt': 0.0}
        microbatch_idx = 0

        pbar = tqdm(
            total=vt_warmup_steps,
            desc='VT warmup',
            dynamic_ncols=True,
        )

        # Re-iterate the loader as often as needed to reach the step budget.
        while vt_warmup_steps_done < vt_warmup_steps:
            saw_batch = False
            for vis, vis_lengths, text_ids, text_lengths, _ in vt_warmup_loader:
                saw_batch = True
                out = model(
                    vis_tokens=vis,
                    vis_lengths=vis_lengths,
                    text_ids=text_ids,
                    text_lengths=text_lengths,
                    in_warmup=True,
                )
                loss, ce, vt = _unpack_losses(out)
                running['loss'] += float(loss.item())
                running['ce'] += ce
                running['vt'] += vt
                (loss / vt_warmup_grad_accum_steps).backward()

                microbatch_idx += 1
                if microbatch_idx % vt_warmup_grad_accum_steps != 0:
                    continue

                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
                global_step += 1
                vt_warmup_steps_done += 1

                # The sums are per micro-batch, so average over micro-batches
                # (dividing by optimizer steps would inflate the values by
                # the accumulation factor).
                avg_loss, avg_ce, avg_vt = _running_means(running, microbatch_idx)

                pbar.update(1)
                pbar.set_postfix(
                    loss=f'{avg_loss:.4f}',
                    ce=f'{avg_ce:.4f}',
                    vt=f'{avg_vt:.4f}',
                    gs=global_step,
                )

                _wandb_log(
                    wandb_run,
                    {
                        'warmup/loss': avg_loss,
                        'warmup/ce': avg_ce,
                        'warmup/vt': avg_vt,
                        'lr': optimizer.param_groups[0]['lr'],
                        'global_step': global_step,
                    },
                    global_step,
                )

                if vt_warmup_steps_done >= vt_warmup_steps:
                    break

            if not saw_batch:
                pbar.close()
                raise ValueError('vt_warmup_loader yielded no batches')

        pbar.close()
        print('VT warmup finished.')

    for epoch in range(1, epochs + 1):
        model.train()
        running = {'loss': 0.0, 'ce': 0.0, 'vt': 0.0}
        optimizer.zero_grad(set_to_none=True)

        # Stays 0 when the loader is empty, so the remainder check is safe.
        step = 0
        pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}', dynamic_ncols=True)
        for step, (vis, vis_lengths, text_ids, text_lengths, _) in enumerate(
            pbar, start=1
        ):
            out = model(
                vis_tokens=vis,
                vis_lengths=vis_lengths,
                text_ids=text_ids,
                text_lengths=text_lengths,
                in_warmup=False,
            )
            loss, ce, vt = _unpack_losses(out)
            running['loss'] += float(loss.item())
            running['ce'] += ce
            running['vt'] += vt
            (loss / grad_accum_steps).backward()

            avg_loss, avg_ce, avg_vt = _running_means(running, step)
            pbar.set_postfix(
                loss=f'{avg_loss:.4f}',
                ce=f'{avg_ce:.4f}',
                vt=f'{avg_vt:.4f}',
                gs=global_step,
            )

            if step % grad_accum_steps != 0:
                continue

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            global_step += 1

            _wandb_log(
                wandb_run,
                {
                    'train/loss': avg_loss,
                    'train/ce': avg_ce,
                    'train/vt': avg_vt,
                    'lr': optimizer.param_groups[0]['lr'],
                    'epoch': epoch,
                    'global_step': global_step,
                },
                global_step,
            )

        # Flush a trailing partial accumulation window.
        # NOTE: these gradients were scaled by 1/grad_accum_steps although
        # fewer micro-batches contributed, so this step is slightly smaller.
        if step % grad_accum_steps != 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            global_step += 1

            avg_loss, avg_ce, avg_vt = _running_means(running, step)
            _wandb_log(
                wandb_run,
                {
                    'train/loss': avg_loss,
                    'train/ce': avg_ce,
                    'train/vt': avg_vt,
                    'lr': optimizer.param_groups[0]['lr'],
                    'epoch': epoch,
                    'global_step': global_step,
                },
                global_step,
            )

        if epoch % val_every_n_epochs != 0:
            continue

        val_res = run_eval(
            model,
            val_loader,
            max_new_tokens=max_new_tokens,
            generate=True,
        )
        print(
            f'\n[VAL] epoch={epoch}  BLEU4={val_res["bleu4"]:.2f}  CE={val_res["ce"]:.4f}'
        )

        _wandb_log(
            wandb_run,
            {
                'val/bleu4': val_res['bleu4'],
                'val/ce': val_res['ce'],
                'epoch': epoch,
                'global_step': global_step,
            },
            global_step,
        )

        _save_checkpoint(
            model, optimizer, scheduler, epoch, global_step, val_res, out_dir, 'last'
        )

        if val_res['bleu4'] > best_bleu4:
            best_bleu4 = val_res['bleu4']
            best_path = _save_checkpoint(
                model, optimizer, scheduler, epoch, global_step, val_res, out_dir, 'best'
            )
            print(f'[CKPT] Saved best -> {best_path}')

    test_res = run_eval(
        model,
        test_loader,
        max_new_tokens=max_new_tokens,
        generate=True,
    )
    print(f'\n[TEST] BLEU4={test_res["bleu4"]:.2f}  CE={test_res["ce"]:.4f}')

    _wandb_log(
        wandb_run,
        {
            'test/bleu4': test_res['bleu4'],
            'test/ce': test_res['ce'],
            'global_step': global_step,
        },
        global_step,
    )


# Training cell: start a W&B run, build loaders/optimizer/scheduler from
# `config`, then run the full training loop.
run = wandb.init(
    project='lingusign-slt',
    config=config,
)

# Checkpoints go to runs/<wandb run name>.
out_dir = os.path.join('runs', run.name)  # type: ignore


###
# hyper params
###

VT_WARMUP_BATCH_SIZE = config['vt_warmup_batch_size']

# Warmup and training both draw from the training set, with different batch sizes.
vt_warmup_loader = _mk_loader(dset_train, shuffle=True, batch_size=VT_WARMUP_BATCH_SIZE)
train_loader = _mk_loader(dset_train, shuffle=True, batch_size=BATCH_SIZE)
val_loader = _mk_loader(dset_val, shuffle=False, batch_size=BATCH_SIZE)
test_loader = _mk_loader(dset_test, shuffle=False, batch_size=BATCH_SIZE)

VT_WARMUP_STEPS = config['vt_warmup_steps']

GRAD_ACCUM_STEPS = config['grad_accum_steps']
num_batches = len(train_loader)
# ceil: a trailing partial accumulation window still costs one optimizer step.
STEPS_PER_EPOCH = math.ceil(num_batches / GRAD_ACCUM_STEPS)

OPTIMIZER = optim.AdamW(
    model.parameters(),
    lr=config['max_lr'],
    betas=config['betas'],
    weight_decay=config['weight_decay'],
)

LR_WARMUP_STEPS = config['lr_warmup_steps']

# Linear LR ramp from ~0 up to max_lr over the first LR_WARMUP_STEPS steps.
LR_WARMUP = optim.lr_scheduler.LinearLR(
    OPTIMIZER, start_factor=1e-8, end_factor=1.0, total_iters=LR_WARMUP_STEPS
)

EPOCHS = config['epochs']
# The scheduler is stepped during the VT warmup phase too, so those steps count.
TOTAL_STEPS = VT_WARMUP_STEPS + STEPS_PER_EPOCH * EPOCHS
assert TOTAL_STEPS > LR_WARMUP_STEPS, 'TOTAL_STEPS must exceed LR_WARMUP_STEPS'

# Cosine decay towards min_lr over the steps remaining after the LR ramp.
COSINE = optim.lr_scheduler.CosineAnnealingLR(
    OPTIMIZER, T_max=TOTAL_STEPS - LR_WARMUP_STEPS, eta_min=config['min_lr']
)
# Switch from the linear ramp to cosine decay at step LR_WARMUP_STEPS.
SCHEDULER = optim.lr_scheduler.SequentialLR(
    OPTIMIZER, schedulers=[LR_WARMUP, COSINE], milestones=[LR_WARMUP_STEPS]
)

# Choose the warmup accumulation so that samples per optimizer step roughly
# match the main phase (BATCH_SIZE * GRAD_ACCUM_STEPS), with a floor of 1.
VT_WARMUP_GRAD_ACCUM_STEPS = max(
    1,
    int(round(BATCH_SIZE * GRAD_ACCUM_STEPS / VT_WARMUP_BATCH_SIZE)),
)

print(f'batch_size: train={BATCH_SIZE}, vt_warmup={VT_WARMUP_BATCH_SIZE}')
print(f'grad_accum: train={GRAD_ACCUM_STEPS}, vt_warmup={VT_WARMUP_GRAD_ACCUM_STEPS}')

train_loop(
    model=model,
    optimizer=OPTIMIZER,
    scheduler=SCHEDULER,
    vt_warmup_loader=vt_warmup_loader,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    max_new_tokens=config['max_text_tokens'],
    epochs=EPOCHS,
    grad_accum_steps=GRAD_ACCUM_STEPS,
    vt_warmup_steps=VT_WARMUP_STEPS,
    vt_warmup_grad_accum_steps=VT_WARMUP_GRAD_ACCUM_STEPS,
    val_every_n_epochs=config['val_every_n_epochs'],
    out_dir=out_dir,
    wandb_run=run,
)

# NOTE(review): not reached if train_loop raises; the run then stays open.
run.finish()

## Test


In [None]:
# Test cell: restore weights from a saved checkpoint and score the test split.
# The run directory is hard-coded; point it at the run you want to evaluate.
checkpoint = torch.load('runs/summer-totem-1/last.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state'])

# Reuses the test split and batch size defined in the shared cell.
test_loader = _mk_loader(dset_test, shuffle=False, batch_size=BATCH_SIZE)

test_res = run_eval(
    model,
    test_loader,
    max_new_tokens=config['max_text_tokens'],
    generate=True,
)
print(f'\n[TEST] BLEU4={test_res["bleu4"]:.2f}  CE={test_res["ce"]:.4f}')