# Get accelerator device

In [None]:
import torch

device = torch.accelerator.current_accelerator(
) if torch.accelerator.is_available() else 'cpu'
print(f'using device: {device}')

# PyTorch training loop

In [None]:
import datetime

import torch
from torch import optim, nn
# --- Configuration (placeholders: replace before running this cell) ---
epoch_num = 1
train_loader = None  # torch dataset loader; iterated below as (X, Y, loss_mask) batches

model = None  # placeholder: assign the model to train before running this cell
learning_rate = 5e-6
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# reduction='none' keeps one loss value per target element so the mask below
# can be applied before averaging.
loss_fn = nn.CrossEntropyLoss(reduction='none')

model.train()
for epoch in range(epoch_num):
    print(f'running epoch: {epoch}')
    for batch, (X, Y, loss_mask) in enumerate(train_loader):
        # compute prediction and loss
        # NOTE: `device` is not defined in this cell; it is expected to come
        # from a previously executed cell. The model is presumably already on
        # that device -- confirm before running.
        X = X.to(device)
        Y = Y.to(device)
        loss_mask = loss_mask.to(device)
        pred = model(X)
        # Collapse every leading dimension of `pred` and flatten `Y` for the
        # loss call, then restore Y's shape so the mask lines up element-wise.
        # reshape (instead of view) also accepts non-contiguous tensors.
        loss = loss_fn(
            pred.reshape(-1, pred.size(-1)), Y.reshape(-1)
        ).view(Y.size())

        # A batch whose mask is entirely zero would make the normalisation
        # below 0/0 = NaN, and stepping on NaN gradients corrupts the weights.
        # Skip such batches instead.
        mask_total = loss_mask.sum()
        if mask_total == 0:
            print(
                f"{datetime.datetime.now()}, epoch: {epoch}, step: {batch}, skipped: empty mask"
            )
            continue

        # Average the loss over the masked-in elements only.
        loss = (loss * loss_mask).sum() / mask_total

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss = loss.item()
        print(
            f"{datetime.datetime.now()}, epoch: {epoch}, step: {batch}, loss: {loss}"
        )

# Save and load PyTorch model

In [None]:
import os
import torch


def latest_checkpoint(checkpoint_dir):
    """Return the path of the last checkpoint file in *checkpoint_dir*.

    "Last" means last in lexicographic order of file name, so checkpoint
    names must sort chronologically (for example, zero-padded epoch numbers).

    Args:
        checkpoint_dir: Directory that holds the checkpoint files.

    Returns:
        ``checkpoint_dir`` joined with the last file name in sorted order, or
        ``None`` when the directory does not exist or holds no regular files.
    """
    try:
        entries = os.listdir(checkpoint_dir)
    except FileNotFoundError:
        # Nothing has been saved yet (e.g. a fresh run): nothing to resume.
        return None

    # Only regular files can be checkpoints; a sub-directory that happens to
    # sort last must not be returned as the "latest checkpoint".
    checkpoints = sorted(
        name
        for name in entries
        if os.path.isfile(os.path.join(checkpoint_dir, name))
    )
    if not checkpoints:
        return None
    return os.path.join(checkpoint_dir, checkpoints[-1])

In [None]:
# --- Placeholders: replace with real objects before running this cell ---
model = None  # model whose weights are restored / saved
# Detect the device the same way as the accelerator cell instead of
# hard-coding 'cuda', so the cell also runs on hosts without CUDA.
device = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else 'cpu'
)
optimizer = None  # optimizer whose state is restored / saved
checkpoint_dir = None  # directory that holds the checkpoint files

# move the model to the target device before loading the state dicts
model = model.to(device)

# load checkpoints
start_epoch = 0
# Default so the save step below does not hit a NameError when no checkpoint
# exists and no loss value has been produced yet.
loss = None
checkpoint_path = latest_checkpoint(checkpoint_dir)
if checkpoint_path is not None:
    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoint files that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Resume with the epoch after the one stored in the checkpoint.
    start_epoch = int(checkpoint['epoch']) + 1
    loss = checkpoint['loss']
    print(
        f"loading checkpoint model from {checkpoint_path}, epoch: {checkpoint['epoch']}, loss: {loss}"
    )


# save checkpoints
epoch = 0  # placeholder: the epoch that has just finished
# torch.save does not create missing parent directories.
os.makedirs(checkpoint_dir, exist_ok=True)
# The epoch number is zero-padded so file names sort lexicographically in
# epoch order, which is how latest_checkpoint picks the newest file.
checkpoint_path = os.path.join(checkpoint_dir, f"model_{epoch:04d}.pt")
torch.save(
    {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    },
    checkpoint_path,
)


# Parse loss curve

In [None]:
import re
import pandas as pd
import matplotlib.pyplot as plt

log_path = 'pre_train.log'

# Number that follows "loss:" in a log line. The optional exponent matters:
# Python prints small floats in scientific notation (e.g. 3.2e-05), which a
# plain "digits.digits" pattern would read as 3.2 or miss entirely.
# Compiled once, outside the per-line loop.
loss_pattern = re.compile(
    r"loss:\s*([-+]?[0-9]+(?:\.[0-9]+)?(?:[eE][-+]?[0-9]+)?)"
)

# Collect one float per log line that carries a loss value; other lines are
# ignored.
loss_vals = []
with open(log_path) as f:
    for line in f:
        loss_match = loss_pattern.search(line)
        if loss_match:
            loss_vals.append(float(loss_match.group(1)))


# One row per parsed value, in file order; plotted against the row index.
loss_df = pd.DataFrame(loss_vals, columns=['loss'])

loss_df['loss'].plot(grid=True)