In [None]:
from ..data_util import *
from ..textencoder import Encoder
from ...data2vec import Data2Vec

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import os


In [None]:
def save_checkpoint(model, path, epoch_num):
    """Write a training checkpoint to ``<path>/ckpt<epoch_num>.pt``.

    The saved dict holds two entries:
      * ``'data2vec'`` -- ``model.state_dict()`` (the full Data2Vec model);
      * ``'encoder'``  -- ``model.encoder.encoder.state_dict()``, the nested
        encoder's weights stored on their own (presumably so the backbone can be
        reloaded without the Data2Vec wrapper -- TODO confirm).

    Args:
        model: module exposing ``state_dict()`` and ``encoder.encoder.state_dict()``.
        path: directory to write into; created (with parents) if missing. An
            empty string means the current working directory.
        epoch_num: epoch number embedded in the file name.

    Returns:
        The full path of the checkpoint file that was written.
    """
    # exist_ok=True replaces the check-then-create pattern, which can race and
    # fail if the directory appears between the check and the mkdir.
    if path:
        os.makedirs(path, exist_ok=True)
    ckpt_path = os.path.join(path, f'ckpt{epoch_num}.pt')
    checkpoint = {
        'data2vec': model.state_dict(),
        'encoder': model.encoder.encoder.state_dict(),
    }
    torch.save(checkpoint, ckpt_path)
    return ckpt_path

In [None]:
# set up system config stuff
# Fall back to CPU when no GPU is visible so the notebook still runs (slowly)
# instead of crashing on the first .to(device).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_epochs = 20
ckpt_dir = './checkpoints/roberta-base'
seed = 0
learning_rate = 2e-3
batch_size = 16

# Seed torch's global RNG so DataLoader shuffling (and any random init/dropout)
# is reproducible across Restart & Run All.
torch.manual_seed(seed)

# Model, Criterion, Optimizer
encoder = Encoder('roberta-base')
model = Data2Vec(encoder=encoder, modality='text', embed_dim=768, ema_decay=0.9998, ema_end_decay=0.9999, ema_anneal_end_step=300000, device=device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# reduction='none' keeps per-element losses; the training loop reduces them itself.
criterion = nn.SmoothL1Loss(reduction='none', beta=2)
criterion.to(device)

# Datasets & Data Loaders
train_dataset = WikiDataset(device, './wikitext-103-v1')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Trackers
loss_tracker = AverageMeter('loss')

In [None]:
# train model, then save
# Relies on num_epochs, ckpt_dir, device, model, optimizer, criterion and
# train_loader from the setup cell (num_epochs is no longer redefined here).
save_every = 5  # epochs between checkpoints; the final epoch is always saved as well

model.train()
for epoch in range(1, num_epochs + 1):
    # Fresh meter per epoch so the progress-bar average describes this epoch only.
    loss_tracker = AverageMeter('loss')
    with tqdm(train_loader, unit="batch", desc=f'Epoch: {epoch} ',
              bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator:
        for src, trg, mask in iterator:
            # move batch to device
            src = src.to(device)
            trg = trg.to(device)
            mask = mask.to(device)

            # forward pass; x is presumably the student prediction and y the
            # teacher target (data2vec convention) -- confirm against Data2Vec
            x, y = model(src, trg, mask)
            # per-element loss summed over everything, normalised by x.size(0)
            loss = criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))

            # clear grads before backward so stale gradients from an interrupted
            # previous run of this cell cannot leak into the first step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # track loss, update teacher model (EMA), refresh progress bar
            loss_tracker.update(loss.item())
            model.ema_step()
            iterator.set_postfix(loss=loss_tracker.avg)

    if epoch % save_every == 0 or epoch == num_epochs:
        save_checkpoint(model, ckpt_dir, epoch)