In [85]:
import torch
from torch.utils.data import DataLoader
from dataclasses import dataclass
import einops
from flax import nnx
import jax.numpy as jnp
import jax
from tokenizer.tokenizer import ChessTokenizer
from torch.utils.data import Dataset
import pandas as pd

import numpy as np

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

## Jax Transformer Implementation

In [164]:
@dataclass
class TransformerConfig:
    debug: bool = True
    d_model: int = 768
    d_vocab: int = 1882
    d_head: int = 64
    n_layers: int = 4
    n_heads: int = 4
    ctx_len: int = 256
    stddev: float = 0.02
    d_mlp: int = d_model*4

In [3]:
class LayerNorm(nnx.Module):
    def __init__(self, cfg: TransformerConfig, eps: float = 1e-05):
        key = jax.random.PRNGKey(101)
        self.cfg = cfg
        self.d_model = self.cfg.d_model
        self.w = nnx.Param(jax.random.normal(key, (self.d_model))) # [d_model]
        self.b = nnx.Param(jnp.zeros(self.d_model,)) # [d_model]
        self.eps = eps
    
    def __call__(self, residual: jax.Array):
        # resdiual: [batch x len x d_model]
        # Make mean 0 and normalize to have variance 1
        y = (residual - jnp.mean(residual, axis=1, keepdims=True)) / (jnp.sqrt(jnp.var(residual) + self.eps))
        # Scale with learned weights
        y = y * self.w
        # Translate with learned bias
        y = y + self.b
        
        return y

In [4]:
class Embed(nnx.Module):
    def __init__(self, cfg: TransformerConfig):
        key = jax.random.PRNGKey(101)
        self.cfg = cfg
        self.W_E = nnx.Param(jax.random.normal(key, (self.cfg.d_vocab, self.cfg.d_model)) * self.cfg.stddev)

    def __call__(self, tokens: jnp.ndarray) -> jnp.ndarray:
        # tokens: [batch length]
        return self.W_E[tokens]

In [5]:
class PosEmbed(nnx.Module):
    def __init__(self, cfg: TransformerConfig):
        key = jax.random.PRNGKey(101)
        self.cfg = cfg
        self.W_pos = nnx.Param(jax.random.normal(key, (cfg.ctx_len, cfg.d_model)) * self.cfg.stddev)

    def __call__(self, tokens: jnp.ndarray) -> jnp.ndarray:
        # tokens: [batch length]
        batch, length = tokens.shape
        return einops.repeat(self.W_pos[:length], 'length d_model -> batch length d_model', batch=batch)

In [6]:
class Attention(nnx.Module):
    def __init__(self, cfg: TransformerConfig, rngs: nnx.Rngs):
        key = rngs.params()
        self.cfg = cfg
        self.W_Q = nnx.Param(jax.random.normal(key, (cfg.n_heads, cfg.d_model, cfg.d_head))) # [num_heads, d_model, d_head]
        self.W_K = nnx.Param(jax.random.normal(key, (cfg.n_heads, cfg.d_model, cfg.d_head))) # [num_heads, d_model, d_head]
        self.W_V = nnx.Param(jax.random.normal(key, (cfg.n_heads, cfg.d_model, cfg.d_head))) # [num_heads, d_model, d_head]
        self.W_O = nnx.Param(jax.random.normal(key, (cfg.n_heads, cfg.d_head, cfg.d_model))) # [num_heads, d_head, d_model]
        self.b_Q = nnx.Param(jnp.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nnx.Param(jnp.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nnx.Param(jnp.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nnx.Param(jnp.zeros((cfg.d_model)))

    def __call__(self, normal_pre_resid: jnp.ndarray) -> jnp.ndarray:
        """
        b = batch
        l = length
        m = d_model
        n = num_heads
        h = d_head
        q = q_pos
        k = k_pos
        """
        # normal_pre_resid: [batch length d_model]
        q = jnp.einsum('blm, nmh -> blnh', normal_pre_resid, self.W_Q.value) + self.b_Q
        k = jnp.einsum('blm, nmh -> blnh', normal_pre_resid, self.W_K.value) + self.b_K
        v = jnp.einsum('blm, nmh -> blnh', normal_pre_resid, self.W_V.value) + self.b_V

        attn_scores = jnp.einsum('bqnh, bknh -> bnqk', q, k)
        attn_scores = self.apply_casual_mask(attn_scores / self.cfg.d_head ** 0.5)
        attn_probs = jax.nn.softmax(attn_scores, axis=-1) # [batch x n_heads x q_pos x k_pos]

        # [batch x q_pos x n_heads x d_head]
        z = jnp.einsum('bnqk, bknh -> bqnh', attn_probs, v)

        out = jnp.einsum('bqnh, nhm -> bqnm', z, self.W_O.value)
        out = jnp.einsum('bqnm -> bqm', out) + self.b_O
        return out

    def apply_casual_mask(self, attn_scores: jnp.ndarray) -> jnp.ndarray:
        # attn_scores: [batch n_heads q_pos k_pos]
        mask = jnp.triu(attn_scores).astype(bool)
        masked_attn_scores = jnp.where(mask,jax.lax.broadcast(-jnp.inf, attn_scores.shape), attn_scores)
        
        return masked_attn_scores

In [7]:
class MLP(nnx.Module):
    def __init__(self, cfg: TransformerConfig):
        key = jax.random.PRNGKey(101)
        self.cfg = cfg
        self.W_in = nnx.Param(jax.random.normal(key, (cfg.d_model, cfg.d_mlp))) # [d_model, d_mlp]
        self.W_out = nnx.Param(jax.random.normal(key, (cfg.d_mlp, cfg.d_model))) # [d_mlp, d_model]
        self.b_in = nnx.Param(jnp.zeros((cfg.d_mlp)))
        self.b_out = nnx.Param(jnp.zeros((cfg.d_model)))

    def __call__(self, normal_resid_mid: jnp.ndarray) -> jnp.ndarray:
        # normal_resid_mid [batch x length x d_model]
        """
        b = batch
        l = length
        m = d_model
        p = d_mlp
        """
        out = jnp.einsum('blm, mp -> blp', normal_resid_mid, self.W_in.value) + self.b_in
        out = jax.nn.gelu(out)
        out = jnp.einsum('blp, pm -> blm', out, self.W_out.value) + self.b_out
        return out


In [8]:
class TransformerBlock(nnx.Module):
    def __init__(self, cfg: TransformerConfig):
        self.cfg = cfg
        self.ln1 = LayerNorm(self.cfg)
        self.ln2 = LayerNorm(self.cfg)
        self.attn = Attention(self.cfg, rngs=nnx.Rngs(params=0))
        self.mlp = MLP(self.cfg)

    def __call__(self, resid_pre: jnp.ndarray) -> jnp.ndarray:
        resid_mid = self.attn(self.ln1(resid_pre))
        resid_post = self.mlp(self.ln2(resid_pre))
        return(resid_post)

In [9]:
class Unembed(nnx.Module):
    def __init__(self, cfg: TransformerConfig):
        key = jax.random.PRNGKey(101)
        self.cfg = cfg
        self.W_U = nnx.Param(jax.random.normal(key, (cfg.d_model, cfg.d_vocab)))
        self.b_U = nnx.Param(jnp.zeros(cfg.d_vocab))

    def __call__(self, normal_resid_post: jnp.ndarray) -> jnp.ndarray:
        # normal_resid_post: [batch x length x d_model]
        """
        b = batch
        l = length
        m = d_model
        b = d_vocab
        """
        return jnp.einsum('blm, mv -> blv', normal_resid_post, self.W_U.value) + self.b_U

In [10]:
class Transformer(nnx.Module):
    def __init__(self, cfg):
        self.cfg = cfg
        self.embed = Embed(self.cfg)
        self.pos_embed = PosEmbed(self.cfg)
        self.blocks = [TransformerBlock(self.cfg) for _ in range(cfg.n_layers)]
        self.ln_final = LayerNorm(self.cfg)
        self.unembed = Unembed(self.cfg)

    def __call__(self, tokens: jnp.ndarray) -> jnp.ndarray:
        resid = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            resid = block(resid)
        logits = self.unembed(self.ln_final(resid))
        return logits

In [165]:
cfg = TransformerConfig(
    d_model=64,
)

In [12]:
def rand_float_test(cls, shape):
    random_input = jax.random.uniform(jax.random.PRNGKey(101), (shape))
    print("Input shape:", random_input.shape)
    output = cls(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

def rand_int_test(cls, shape):
    random_input = jax.random.randint(jax.random.PRNGKey(101), (shape), 100, 1000)
    print("Input shape:", random_input.shape)
    output = cls(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

In [13]:
# LayerNorm test
ln = LayerNorm(cfg)
rand_int_test(ln, (2, 4, cfg.d_model))

# Embed test
emb = Embed(cfg)
rand_int_test(emb, (2, 128))

# PosEmbed test
pos = PosEmbed(cfg)
rand_int_test(emb, (2, 128))

# Attention test
attn = Attention(cfg, rngs=nnx.Rngs(params=0))
rand_float_test(attn, (2, 128, cfg.d_model))

# MLP test
mlp = MLP(cfg)
rand_float_test(attn, (2, 128, cfg.d_model))

# TransformerBlock test
tb = TransformerBlock(cfg)
rand_float_test(attn, (2, 128, cfg.d_model))

# Unembed test
un = Unembed(cfg)
rand_float_test(attn, (2, 128, cfg.d_model))

# Transformer test
t = Transformer(cfg)
rand_int_test(emb, (2, 128))

Input shape: (2, 4, 64)
Output shape: (2, 4, 64) 

Input shape: (2, 128)
Output shape: (2, 128, 64) 

Input shape: (2, 128)
Output shape: (2, 128, 64) 

Input shape: (2, 128, 64)
Output shape: (2, 128, 64) 

Input shape: (2, 128, 64)
Output shape: (2, 128, 64) 

Input shape: (2, 128, 64)
Output shape: (2, 128, 64) 

Input shape: (2, 128, 64)
Output shape: (2, 128, 64) 

Input shape: (2, 128)
Output shape: (2, 128, 64) 



In [14]:
tokenizer = ChessTokenizer()
tokenizer.load_tokenizer("./tokenizer/vocab.json")

In [167]:
test_game = ["<|startofgame|>", "e2e4", "c7c5", "g1f3", "d7d6", "f1b5", "c8d7", "d1e2", "g8f6", "b2b3", "e7e6", "c1b2", "f8e7", "e4e5", "d6e5", "f3e5", "e8g8", "e1g1", "a7a6", "e5d7", "b8d7", "b5d3", "b7b5", "a2a4", "c5c4", "b3c4", "b5b4", "d3e4", "a8b8", "d2d3", "a6a5", "b1d2", "d7c5", "b2e5", "b8b6", "e4f3", "d8d7", "d2b3", "b6a6", "b3c5", "e7c5", "d3d4", "c5e7", "f1d1", "f8c8", "c4c5", "a6a7", "e2b5", "f6d5", "b5d7", "a7d7", "d1d3", "f7f6", "f3g4", "g8f7", "e5g3", "f6f5", "g4h5", "g7g6", "h5f3", "e7f6", "g3d6", "d7d6", "c5d6", "c8c2", "f3d5", "e6d5", "a1e1", "c2c6", "d6d7", "c6d6", "e1e8", "d6d7", "e8a8", "f6d8", "g2g3", "f7e6", "a8a6", "e6f7", "g1g2", "g6g5", "g2f3", "d8c7", "a6a7", "g5g4", "f3g2", "f7e6", "a7b7", "e6d6", "b7b5", "d7e7", "g2f1", "e7e4", "b5b7", "h7h5", "b7b5", "f5f4", "f2f3", "g4f3", "g3f4", "e4f4", "f1f2", "c7d8", "b5b7", "d8h4", "f2f1", "f3f2", "b7b6", "1-0", "<|endofgame|>"]
input_ids = tokenizer.encode(test_game)
batched_input_ids = jnp.expand_dims(input_ids, 0)

transformer = Transformer(cfg)
logits = transformer(batched_input_ids)

greedy_pred = jax.nn.softmax(logits[0,-1], axis=-1).argmax()
print(f"Next predicted move is: {tokenizer.decode([greedy_pred])[0]}")

Next predicted move is: f2d1


In [166]:
@dataclass
class TraningArgs():
    batch_size = 16
    epochs = 10
    max_steps_per_epoch = 200
    lr = 1e-3
    weight_decay = 1e-2
    wandb_project_name: str | None = "ChessTransformer"
    wandb_name: str | None = None

args = TraningArgs()

In [181]:
import mmap
import csv
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset


# Lots of inspo taken from https://github.com/codyjk/ChessGPT/blob/main/src/chess_model/data/dataset.py -- thank you :^)
class GamesDataset(Dataset):
    def __init__(self, filename: str, tokenizer, context_length=256):
        self.filename = filename
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.line_offsets = []
        self.file = open(self.filename, "r")

        with open(self.filename, 'rb') as f:
            mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
            total_size = mm.size()
            self.line_offsets.append(0)

            with tqdm(
                total=total_size, unit="B", unit_scale=True, desc="Indexing CSV file"
            ) as pbar:
                while mm.readline():
                    current_pos = mm.tell()
                    self.line_offsets.append(current_pos)
                    pbar.update(current_pos - pbar.n)

            mm.close()

            self.line_offsets.pop()

    def __len__(self):
        return len(self.line_offsets) - 1

    def __getitem__(self, idx) -> dict:
        # Add 1 to idx to skip the header
        items = {
            "input_ids": jnp.empty((1, self.context_length),dtype=jnp.int32), # [batch x context_length]
            "labels": jnp.empty((1, self.context_length), dtype=jnp.int32), # [batch x context_length]
            "is_checkmate": jnp.empty((1, 1), dtype=jnp.int32),
            "outcome": jnp.empty((1,3), dtype=jnp.int32),
            "move_mask": jnp.empty((1, self.context_length), dtype=jnp.int32),
        }
        if isinstance(idx, int):
            idx = [idx]

        for i in idx:
            """
            this is a super hack bc im trying to get a line that doesnt exist
            on second thought, dont think this is super hacky since we need to + 1 to skip the label row
            but this means that if i == len(self.line_offsets) when its + 1'd it will be oob, so decrement just that 
            """
            ## TODO fix me plsss
            if i == len(self.line_offsets) - 1:
                i -= 1
            self.file.seek(self.line_offsets[i + 1])
            line = self.file.readline().strip()

            # Parse the CSV line
            row = next(csv.reader([line]))
            context, is_checkmate, outcome = row

            context = context.split() if context else []
            context, last_move = context[:-1], context[-1]
            is_checkmate = jnp.array(jnp.expand_dims(float(is_checkmate == "True"), axis=0), dtype=jnp.int32)

            input_ids = self.tokenizer.encode_and_pad(context, self.context_length)

            # Shift context to the left to create labels
            # The next move prediction for input_ids[n] is labels[n]
            labels = context[1:] + [last_move]
            labels = self.tokenizer.encode_and_pad(labels, self.context_length)

            # If white won, we want the model to learn from white's moves, not black's.
            # Conversely, if black won, we want the model to learn from black's moves.
            # For draws, we want the model to learn from both moves.
            # We will produce a mask that masks out the moves for the losing player,
            # and the model will learn from the remaining moves.
            move_mask = jnp.ones(self.context_length, dtype=jnp.int32)

            if outcome == "1-0":  # White won
                # Mask out odd-indexed moves (Black's moves)
                move_mask = move_mask.at[1::2].set(0.0)
            elif outcome == "0-1":  # Black won
                # Mask out even-indexed moves (White's moves)
                move_mask = move_mask.at[::2].set(0.0)
            # For draws (1/2-1/2), keep all moves (mask stays 1)

            # If the context is shorter than max_context_length, zero-out that part of the mask
            if len(context) < self.context_length:
                move_mask = move_mask.at[len(context) :].set(0.0)

            # Convert outcome to one-hot encoding (as float)
            outcome_label = jnp.zeros(3, dtype=jnp.int32)
            if outcome == "1-0":
                outcome_label = outcome_label.at[0].set(1.0)
            elif outcome == "0-1":
                outcome_label = outcome_label.at[1].set(1.0)
            elif outcome == "1/2-1/2":
                outcome_label = outcome_label.at[2].set(1.0)

            items["input_ids"] = jnp.concat((items["input_ids"], jnp.array(jnp.expand_dims(input_ids, axis=0), dtype=jnp.int32))) # expand dims to add empty batch dim
            items["labels"] = jnp.concat((items["labels"], jnp.array(jnp.expand_dims(input_ids, axis=0), dtype=jnp.int32)))
            items["is_checkmate"] = jnp.concat((items["is_checkmate"], jnp.array(jnp.expand_dims(is_checkmate, axis=1), dtype=jnp.int32)))
            items["outcome"] = jnp.concat((items["outcome"], jnp.expand_dims(outcome_label, axis=0)))
            items["move_mask"] = jnp.concat((items["move_mask"], jnp.expand_dims(move_mask, axis=0)))

        # kinda (very) hacky, removes the first dim since it was the empty dim when initialized
        return {
        "input_ids": items["input_ids"][1:],
        "labels": items["labels"][1:],
        "is_checkmate": items["is_checkmate"][1:],
        "outcome": items["outcome"][1:],
        "move_mask": items["move_mask"][1:],
    }

    def __del__(self):
        # Close the file when the dataset object is destroyed
        if hasattr(self, "file"):
            self.file.close() 

    def train_test_split(self, test_size: float = 0.2, random_state: int = 1234) -> dict:
        train_indicies, test_indicies = train_test_split(
            range(len(self.line_offsets)),
            test_size=test_size,
            random_state=random_state
        )
        train_dataset = Subset(self, train_indicies)
        test_dataset = Subset(self, test_indicies)
        return {
            "train": train_dataset,
            "test": test_dataset
        }

    @staticmethod
    def collate_fn(batch):
        out = batch[0]
        for d in batch[1:]:
            for k in d.keys():
                out[k] = jnp.concat((out[k], d[k]))
        return out

In [78]:
dataset = GamesDataset("chess_games.csv", tokenizer)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers = 0, collate_fn=GamesDataset.collate_fn)

# Test batching works
for i_batch, sample_batched in enumerate(dataloader):
    if i_batch == 1:
        for d in sample_batched:
            pass
        break

Indexing CSV file: 100%|████████████████████████████████████████████████████████████| 872k/872k [00:00<00:00, 818MB/s]


In [102]:
@dataclass
class TransformerTrainingArgs():
    batch_size = 16
    epochs: int = 10
    max_steps_per_epoch: int = 200
    lr = 1e-3
    weight_decay = 1e-2
    wandb_project: str | None = "ChessTransformer"
    wandb_name: str | None = None


In [None]:
import optax
import wandb

class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: Transformer, train_loader, test_loader):
        super().__init__()
        self.model = model
        self.args = args
        self.optimizer = nnx.Optimizer(self.model, optax.adamw(learning_rate=args.lr, weight_decay=args.weight_decay))
        self.step = 0
        self.train_loader = train_loader
        self.test_loader = test_loader

    def training_step(self, batch: dict) -> jnp.ndarray:
        def loss_fn(model: Transformer):
            y_pred = model(batch["input_ids"])
            log_probs = self.get_log_probs(y_pred, batch["input_ids"])
            return -log_probs.mean()

        loss, grads = nnx.value_and_grad(loss_fn)(self.model)
        self.optimizer.update(grads)

        self.step += 1
        wandb.log({"train_loss":loss}, step=self.step)
        return loss

    def validation_step(self, batch: dict) -> jnp.ndarray:
        tokens = batch["input_ids"]
        logits = self.model(tokens)[:,:-1]
        pred_tokens = jnp.argmax(logits, axis=-1)
        correct = (pred_tokens == tokens[:, 1:]).flatten()

        return correct

    def train(self):
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = jnp.nan
        total_steps = self.args.epochs * self.args.max_steps_per_epoch

        with tqdm(total=total_steps, desc="Training Epochs") as progress_bar:
            for epoch in range(self.args.epochs):
                for i, batch in enumerate(self.train_loader):
                    loss = self.training_step(batch)
                    progress_bar.update()
                    progress_bar.set_description(f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.2f}")
                    if i >= self.args.max_steps_per_epoch:
                        break

                correct = jnp.concat([self.validation_step(batch) for batch in self.test_loader])
                accuracy = jnp.mean(correct)
                wandb.log({"accuracy":accuracy}, step=self.step)

        wandb.finish()

    # TODO write unittest for this PLEASE
    def get_log_probs(self, logits: jnp.ndarray, tokens: jnp.ndarray) -> jnp.ndarray:
        log_probs = nnx.log_softmax(logits, axis=-1)
        sliced_log_probs = log_probs[:, :-1]
        next_token_indicies = jnp.expand_dims(tokens[:, 1:], axis=-1).astype(jnp.int32)
        log_probs_for_tokens = jnp.take_along_axis(
            sliced_log_probs, next_token_indicies, axis=-1
        )

        return log_probs_for_tokens

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f7600d87290>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 7f75fe862610, raw_cell="import optax
import wandb

class TransformerTraine.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/wizard/projects/chess/transformer.ipynb#Y106sZmlsZQ%3D%3D>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f7600d87290>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f72e0f4acd0, execution_count=172 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 7f75fe862610, raw_cell="import optax
import wandb

class TransformerTraine.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/wizard/projects/chess/transformer.ipynb#Y106sZmlsZQ%3D%3D> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

In [180]:
dataset = GamesDataset("chess_games.csv", tokenizer, context_length=256)
dataset_dict = dataset.train_test_split(test_size=1000)

train_dataset = dataset_dict["train"]
test_dataset = dataset_dict["test"]

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True, collate_fn=GamesDataset.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True, collate_fn=GamesDataset.collate_fn)

args = TransformerTrainingArgs(epochs=10, max_steps_per_epoch=200)
trainer = TransformerTrainer(args, transformer, train_loader=train_loader, test_loader=test_loader)
trainer.train()

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f74d5f6c3d0>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 7f72e7aea8d0, raw_cell="dataset = GamesDataset("chess_games.csv", tokenize.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/wizard/projects/chess/transformer.ipynb#Y103sZmlsZQ%3D%3D>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

Indexing CSV file: 100%|████████████████████████████████████████████████████████████| 872k/872k [00:00<00:00, 797MB/s]


0,1
train_loss,▂▅▇▄█▆▁▄█▃▅▇▄▇▇█▆▁▁▅

0,1
train_loss,2.15735


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669611033285035, max=1.0…

Epoch 1, loss: 1.350, accuracy: nan:   4%|█▌                                      | 76/2000 [02:03<1:00:14,  1.88s/it]

acc: 0.7146509885787964


Epoch 2, loss: 0.959, accuracy: 0.71:   8%|███                                     | 152/2000 [05:32<36:16,  1.18s/it]

acc: 0.715843141078949


Epoch 4, loss: 1.065, accuracy: 0.72:  15%|██████                                  | 304/2000 [12:30<33:36,  1.19s/it]

acc: 0.7176784873008728


Epoch 5, loss: 1.162, accuracy: 0.72:  19%|███████▌                                | 380/2000 [15:57<29:58,  1.11s/it]

acc: 0.7187294363975525


Epoch 6, loss: 2.466, accuracy: 0.72:  23%|█████████                               | 456/2000 [19:25<30:52,  1.20s/it]

acc: 0.7190862894058228


Epoch 7, loss: 1.232, accuracy: 0.72:  27%|██████████▋                             | 532/2000 [22:53<28:51,  1.18s/it]

acc: 0.7191490530967712


Epoch 8, loss: 1.860, accuracy: 0.72:  30%|████████████▏                           | 608/2000 [26:20<26:54,  1.16s/it]

acc: 0.7189686894416809


Epoch 9, loss: 1.794, accuracy: 0.72:  34%|█████████████▋                          | 684/2000 [29:48<26:54,  1.23s/it]

acc: 0.7184706330299377


Epoch 10, loss: 1.194, accuracy: 0.72:  38%|██████████████▊                        | 760/2000 [34:45<56:42,  2.74s/it]

acc: 0.7184706330299377





VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
train_loss,▆▄█▅▅▄█▆█▅▅▆▅▅▆▅▃▅▄▆▂▄▄▄▄▄▄▆▄▄▅▃▅▃▄▅▄▃▃▁

0,1
accuracy,0.71847
train_loss,1.19364
