# Environment Setup
Install required dependencies, set up GPU, import libraries, and clone project repository.

In [None]:
# Install required dependencies
# PyTorch wheels are pulled from the cu118 index (presumably the CUDA 11.8 builds —
# confirm this matches the runtime's CUDA version before running).
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Remaining scientific / plotting / config helpers used by the rest of the notebook.
!pip install scipy pyyaml fire seaborn matplotlib tqdm

In [None]:
# Check GPU availability
# Fail fast: the cell raises when no CUDA device is visible instead of letting
# later cells fall back to CPU.
import torch
if not torch.cuda.is_available():
    raise SystemError("GPU is not available. Please ensure you are using a Colab runtime with GPU enabled.")
else:
    # Report the name of CUDA device 0.
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")


# Import required libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import yaml
import json
import warnings
warnings.filterwarnings("ignore")

# importing required libraries
import torch.backends.cudnn as cudnn
import random

import math
from torch import autograd

# Mount Google Drive
Mount Google Drive to persist data and results between Colab sessions.

In [None]:
# Colab-only helper: this import is unavailable outside Google Colab.
from google.colab import drive

# Mount Google Drive to persist data and results
drive.mount('/content/drive')

# Set the working directory to a folder in Google Drive
# The project folder is created on first run (exist_ok=True makes re-runs safe),
# then the %cd magic makes it the working directory for subsequent cells.
import os
os.makedirs('/content/drive/MyDrive/ColabNotebooks/hessian-ood', exist_ok=True)
%cd /content/drive/MyDrive/ColabNotebooks/hessian-ood

# Model Implementation
Define the Transformer model architecture including embedding layers, attention mechanisms, and feed-forward networks.

In [12]:
#####################################################
##################### model definition ####################
#####################################################


class Embedding(nn.Module):
    """Token-embedding lookup table, optionally initialised from given weights."""

    def __init__(self, vocab_size, d_model, init=None, trainable=True):
        """
        Args:
            vocab_size: number of rows in the table (not used when init is given)
            d_model: width of every embedding vector (not used when init is given)
            init: optional 2d weight tensor passed to nn.Embedding.from_pretrained
            trainable: whether the embedding weights require gradients
        """
        super().__init__()
        table = (
            nn.Embedding(vocab_size, d_model)
            if init is None
            else nn.Embedding.from_pretrained(init)
        )
        self.embed = table.requires_grad_(trainable)

    def forward(self, x):
        """
        Args:
            x: tensor of integer token indices
        Returns:
            the embedding vectors looked up for every index in x
        """
        return self.embed(x)


class PositionalEmbedding(nn.Module):
    """Learned absolute positional embedding that is added to token embeddings."""

    def __init__(self, max_seq_len, d_model, init=None, trainable=True):
        """
        Args:
            max_seq_len: maximum length of an input sequence (number of positions)
            d_model: embedding width; must equal the width of the token embeddings
                because the two are summed, not concatenated
            init: optional 2d weight tensor passed to nn.Embedding.from_pretrained
            trainable: whether the positional weights require gradients
        """
        super().__init__()
        if init is not None:
            self.pe = nn.Embedding.from_pretrained(init).requires_grad_(trainable)
        else:
            self.pe = nn.Embedding(max_seq_len, d_model).requires_grad_(trainable)

    def forward(self, x):
        """
        Args:
            x: token embeddings, indexed as (batch, seq_len, d_model)
        Returns:
            out: x with the embedding of positions 0..seq_len-1 added to every sequence
        """
        seq_len = x.size(1)
        # Build the indices on the same device as the table; a CPU index tensor
        # breaks the lookup as soon as the module has been moved to the GPU.
        pos = torch.arange(seq_len, dtype=torch.long, device=self.pe.weight.device)
        # Broadcasting over the batch dimension replaces an explicit repeat().
        return x + self.pe(pos).unsqueeze(0)


## following three definitions are for rotary embeddings
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Rotary embedding helper: tabulate unit-modulus complex rotations.

    Args:
        dim: head dimension; dim // 2 rotation frequencies are produced
        end: number of positions to tabulate
        theta: base of the geometric frequency schedule
    Returns:
        complex tensor of shape (end, dim // 2) whose entry (t, k) is
        exp(i * t * theta ** (-2k / dim))
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta**exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    unit_modulus = torch.ones_like(angles)
    return torch.polar(unit_modulus, angles)  # complex64


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    View freqs_cis so that it broadcasts against x.

    freqs_cis must have shape (x.shape[1], x.shape[-1]); the result keeps those
    two sizes at axis 1 and at the last axis and has size 1 on every other axis.
    """
    rank = x.ndim
    assert 0 <= 1 < rank
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (freqs_cis.shape, x.shape)
    last_axis = rank - 1
    target_shape = []
    for axis, size in enumerate(x.shape):
        target_shape.append(size if axis in (1, last_axis) else 1)
    return freqs_cis.view(*target_shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """
    Rotate queries and keys by the given complex frequencies.

    Consecutive pairs along the last axis of xq / xk are interpreted as
    (real, imag) parts, multiplied by freqs_cis and written back as real pairs.
    The caller in this file passes tensors laid out as
    (batch_size, seq_len, num_head, d_head).

    Returns:
        (xq_out, xk_out) with the same dtypes as the corresponding inputs
    """

    def _pairs_to_complex(t):
        # split the last axis into (..., d/2, 2) and view each pair as a complex number
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _pairs_to_complex(xq)
    k_complex = _pairs_to_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


# The following implementation of multi-head attention is from
# https://towardsdatascience.com/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with an optional positional term inside the attention.

    config.pos selects the positional scheme:
        "relative": a learned bias of shape (num_heads, max_seq_len, max_seq_len)
                    is added to the scaled attention logits
        "rotary":   queries and keys are rotated with precomputed complex frequencies
        otherwise:  no positional term is applied here
    With linear_attn=True the softmax is skipped and the masked logits are used
    directly as attention weights.
    """

    def __init__(self, config, linear_attn=False):
        """
        Args:
            config: object exposing d_model, num_heads, pos, max_seq_len, device
                    (and rotary_theta when pos == "rotary")
            linear_attn: if True, use masked logits instead of softmax probabilities
        """
        super().__init__()
        assert (
            config.d_model % config.num_heads == 0
        ), "d_model must be divisible by num_heads"
        self.d_model = config.d_model
        self.num_heads = config.num_heads
        self.linear_attn = linear_attn
        self.pos = config.pos
        self.d_k = config.d_model // config.num_heads

        self.W_q = nn.Linear(config.d_model, config.d_model)
        self.W_k = nn.Linear(config.d_model, config.d_model)
        self.W_v = nn.Linear(config.d_model, config.d_model)
        self.W_o = nn.Linear(config.d_model, config.d_model)

        if config.pos == "relative":
            # Allocate directly on the target device. Calling .to(device) on an
            # nn.Parameter returns a plain tensor whenever a copy is made, which
            # silently drops the bias from parameters() and from the state_dict.
            self.att_bias = nn.Parameter(
                torch.zeros(
                    config.num_heads,
                    config.max_seq_len,
                    config.max_seq_len,
                    device=config.device,
                )
            )

        if config.pos == "rotary":
            # Kept as a plain attribute (not a buffer) so that module-wide dtype
            # casts never touch this complex tensor.
            self.freqs_cis = precompute_freqs_cis(
                self.d_model // self.num_heads,
                config.max_seq_len * 2,
                config.rotary_theta,
            ).to(config.device)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Args:
            Q, K, V: per-head tensors of shape (batch, num_heads, seq_len, d_k)
            mask: tensor broadcastable to the logits; entries equal to 0 are masked out.
                  Required (mask=None is not supported).
        Returns:
            output: attention-weighted values, same layout as V
            (attn_probs, QK_vals): attention weights and the unmasked scaled logits
        """
        assert mask is not None, "Mask=None is not supported now"

        if self.pos == "rotary":
            T = Q.size(2)
            # apply_rotary_emb expects (batch_size, seq_len, num_head, d_head);
            # the frequencies follow the inputs in case the module was moved
            # after construction.
            Q, K = apply_rotary_emb(
                Q.transpose(1, 2),
                K.transpose(1, 2),
                freqs_cis=self.freqs_cis[:T].to(Q.device),
            )
            Q, K = Q.transpose(1, 2), K.transpose(1, 2)

        QK_vals = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if self.pos == "relative":
            T = QK_vals.size(2)
            QK_vals = QK_vals + self.att_bias[:, :T, :T].view(1, self.num_heads, T, T)

        if self.linear_attn:
            # linear attention: masked positions contribute exactly zero
            attn_probs = QK_vals.masked_fill(mask == 0, 0)
        else:
            # a large negative logit drives the softmax weight of masked positions to 0
            attn_probs = torch.softmax(QK_vals.masked_fill(mask == 0, -1e9), dim=-1)
        output = torch.matmul(attn_probs, V)
        return output, (attn_probs, QK_vals)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) into (batch, num_heads, seq_len, d_k)."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of split_heads: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None, output_attn=False):
        """
        Args:
            Q, K, V: inputs of shape (batch, seq_len, d_model)
            mask: attention mask, see scaled_dot_product_attention
            output_attn: if True, also return (attn_probs, QK_vals)
        Returns:
            output, or (output, (attn_probs, QK_vals)) when output_attn is True
        """
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output, attn_info = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return (output, attn_info) if output_attn else output


class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied to the last axis: Linear -> ReLU -> Linear."""

    def __init__(self, d_model, d_ff):
        """
        Args:
            d_model: input and output width
            d_ff: hidden width
        """
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Expand to d_ff, apply ReLU, project back to d_model."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)


class PositionWiseFeedForward2(nn.Module):
    """Feed-forward block with an optional input LayerNorm and optional output dropout."""

    def __init__(self, d_model, d_ff, dropout=None, norm=False):
        """
        Args:
            d_model: input and output width
            d_ff: hidden width
            dropout: dropout probability, or None to skip dropout entirely
            norm: if True, LayerNorm is applied to the input first
        """
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
        self.drop = dropout
        self.norm = norm
        # the dropout module only exists when a probability was supplied
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """(LayerNorm) -> Linear -> ReLU -> Linear -> (Dropout)."""
        if self.norm:
            x = self.ln(x)
        x = self.fc2(self.relu(self.fc1(x)))
        if self.drop is not None:
            x = self.dropout(x)
        return x


## main Transformer definitions
class TFBlock(nn.Module):
    """
    One Transformer block: (LayerNorm) -> multi-head attention -> (dropout),
    optionally followed by (LayerNorm) -> feed-forward -> (dropout).

    Each sub-layer output is added to its input when config.residual is set and
    replaces it otherwise.
    """

    def __init__(
        self,
        config,
    ):
        """
        Args:
            config: object exposing vocab_size, residual, dropout, norm, ff_dim,
                    linear_attn, mlp, d_model and trainable_norm, plus whatever
                    MultiHeadAttention reads from it. dropout may be None to
                    disable dropout.
        """
        super().__init__()
        self.vocab_size = config.vocab_size
        self.residual = config.residual
        self.drop = config.dropout
        self.norm = config.norm
        self.ff_dim = config.ff_dim
        self.linear_attn = config.linear_attn
        self.mlp = config.mlp

        # initiating layers
        self.mha = MultiHeadAttention(config, linear_attn=config.linear_attn)
        self.ln_1 = nn.LayerNorm(config.d_model, elementwise_affine=config.trainable_norm)
        self.ln_2 = nn.LayerNorm(config.d_model, elementwise_affine=config.trainable_norm)
        # nn.Dropout cannot be built with p=None, so dropout=None maps to a
        # parameter-free identity; forward() then needs no special-casing.
        self.dropout1 = self._make_dropout(config.dropout)
        self.dropout2 = self._make_dropout(config.dropout)
        if config.mlp:
            self.feed_forward = PositionWiseFeedForward(config.d_model, config.ff_dim)

    @staticmethod
    def _make_dropout(p):
        """Return nn.Dropout(p), or nn.Identity() when dropout is disabled (p is None)."""
        return nn.Identity() if p is None else nn.Dropout(p)

    def forward(self, x, mask):
        """
        Args:
            x: input of shape (batch, seq_len, d_model)
            mask: attention mask forwarded to MultiHeadAttention
        Returns:
            tensor with the same shape as x
        """
        out = self.ln_1(x) if self.norm else x
        out = self.dropout1(self.mha(out, out, out, mask))
        # Without a residual connection the sub-layer output must replace x;
        # keeping x would silently turn the block into an identity map.
        x = x + out if self.residual else out
        if self.mlp:
            out = self.ln_2(x) if self.norm else x
            out = self.dropout2(self.feed_forward(out))
            x = x + out if self.residual else out
        return x


class TFModel(nn.Module):
    """
    Decoder-only Transformer: token embedding, optional learned absolute positional
    embedding, a stack of TFBlock layers under a causal mask, optional final
    LayerNorm and a linear projection to vocabulary logits.
    """

    # positional schemes that are handled inside the attention layers
    _ATTENTION_POS_SCHEMES = ("relative", "rotary")

    def __init__(self, config):
        """
        Args:
            config: object exposing device, pos, num_layers, output_norm, vocab_size,
                    d_model, max_seq_len and trainable_norm, plus whatever TFBlock
                    reads from it
        """
        super().__init__()
        self.device = config.device
        self.pos = config.pos
        self.num_layers = config.num_layers
        self.output_norm = config.output_norm
        self.vocab_size = config.vocab_size
        self.embed = Embedding(config.vocab_size, config.d_model)
        # Use the same condition as forward(): every scheme that is not handled
        # inside attention needs the absolute positional table. Testing only for
        # `pos is None` left other values without a pos_embed attribute.
        if self.pos not in self._ATTENTION_POS_SCHEMES:
            self.pos_embed = PositionalEmbedding(config.max_seq_len, config.d_model)

        self.h = nn.ModuleList([TFBlock(config) for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.d_model, elementwise_affine=config.trainable_norm)
        self.fc = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, src):
        """
        Args:
            src: integer token indices of shape (batch, seq_len)
        Returns:
            logits of shape (batch, seq_len, vocab_size)
        """
        x = self.embed(src)
        if self.pos not in self._ATTENTION_POS_SCHEMES:
            x = self.pos_embed(x)
        seq_len = x.size(1)
        # Lower-triangular causal mask, built directly on the activations' device
        # and shaped (1, 1, seq_len, seq_len) to broadcast over batch and heads.
        mask = (
            torch.tril(torch.ones(seq_len, seq_len, device=x.device))
            .unsqueeze(0)
            .unsqueeze(0)
        )
        out = x
        for block in self.h:
            out = block(out, mask)
        out = self.ln_f(out) if self.output_norm else out
        return self.fc(out)


# Data Generation
Define the synthetic data-generation functions (pattern repetition, HMM, pattern insertion, and Markov-chain sequences).

In [13]:
#####################################################
##################### Data generation ####################
#####################################################


def gen_simple_data(
    vocab,
    max_seq_len,
    sample_size,
    pattern="random",
    pattern_sample_len=None,
    rep_l=11,
    rep_h=20,
    return_lens=False,
):
    """
    Generate input sequences for training/testing based on different patterns.
    Each sequence is a short pattern tiled end to end and cut to max_seq_len tokens
    starting from a random offset.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: positive integer that specifies the maximum number of tokens in a sequence
        sample_size: the number of input sequences
        pattern: the string 'random' to sample a fresh pattern for every sequence, or a
            1-d array/tensor of vocabulary indices used as the pattern for all sequences
        pattern_sample_len: the length of sampled patterns, only used when pattern='random';
            if None, the length is drawn uniformly from [rep_l, rep_h) for every sequence
        rep_l: smallest sampled pattern length (inclusive)
        rep_h: upper bound on the sampled pattern length (exclusive)
        return_lens: if True, returns repetition length for each seq
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long; if return_lens, returns (data, lens)
    """
    vocab_size = vocab.size(0)
    data = torch.zeros(sample_size, max_seq_len, dtype=torch.long)
    lens = np.zeros(sample_size, dtype=int)
    if max_seq_len % 12 != 0:
        warnings.warn("max_seq_len is not divisible by 12, which may cause issues!")

    # Compare against the string only when pattern is a string: `array == "random"`
    # is an element-wise comparison whose truth value is ambiguous.
    sample_pattern = isinstance(pattern, str) and pattern == "random"

    for i in range(sample_size):
        if sample_pattern:
            if pattern_sample_len is None:
                pattern_len = int(np.random.randint(low=rep_l, high=rep_h))
            else:
                # honour the requested length (previously pattern_len stayed unassigned)
                pattern_len = int(pattern_sample_len)
            pattern_idx = torch.multinomial(
                torch.ones(vocab_size) / vocab_size, pattern_len, replacement=True
            )
        else:  # for a given pattern
            pattern_idx = pattern
            pattern_len = len(pattern)

        # tile one repetition more than needed so that any start offset < r still
        # leaves max_seq_len tokens to slice
        num_repeat = max_seq_len // pattern_len + 1
        r = pattern_len * num_repeat - max_seq_len
        tiled = vocab[pattern_idx].repeat(num_repeat)
        start_pos = np.random.randint(low=0, high=r)
        data[i, :] = tiled[start_pos : (start_pos + max_seq_len)]
        lens[i] = pattern_len

    return (data, lens) if return_lens else data


def gen_repetition_data(
    vocab,
    max_seq_len,
    sample_size,
    distr=None,
    pattern_pool_size=None,
    patterns=None,
    rep_l=11,
    rep_h=20,
    num_repeat=2,
    return_lens=False,
):
    """
    Generate random sequences in which one pattern is planted num_repeat times.
    The background tokens and the patterns are token indices drawn from distr;
    vocab is only used for its size. The copies of a pattern never overlap and
    are placed at random positions in increasing order.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: positive integer that specifies the number of tokens in a sequence
        sample_size: the number of input sequences
        distr: 1d tensor of sampling weights over token indices; uniform if None
        pattern_pool_size: if given (and patterns is None), a pool of this many patterns
            is sampled once and every sequence uses a pattern from that pool
        patterns: optional list of 1d index tensors used as the pattern pool; when
            pattern_pool_size is also given only the first pattern_pool_size entries are used
        rep_l: smallest sampled pattern length (inclusive)
        rep_h: upper bound on the sampled pattern length (exclusive)
        num_repeat: number of copies of the pattern per sequence (at least 1)
        return_lens: if True, also return pattern lengths, start positions and the pool
    Returns:
        data: 2d torch.Tensor of type torch.long; if return_lens, returns
        (data, lens, starts, patterns), where starts[i, j] is the first index of the
        j-th copy in sequence i and patterns is the pool (None if none was used)
    Raises:
        ValueError: if the copies of a pattern cannot be placed inside max_seq_len
    """
    vocab_size = vocab.size(0)
    p = torch.ones(vocab_size) / vocab_size if distr is None else distr

    data = torch.multinomial(p, sample_size * max_seq_len, replacement=True).view(
        sample_size, max_seq_len
    )
    lens = np.zeros(sample_size, dtype=int)
    starts = np.zeros((sample_size, num_repeat), dtype=int)

    if pattern_pool_size is not None and patterns is None:
        # sample the pool from p, with lengths drawn uniformly from [rep_l, rep_h)
        pattern_len_all = np.random.randint(
            low=rep_l, high=rep_h, size=pattern_pool_size
        )
        patterns = [
            torch.multinomial(p, int(n), replacement=True) for n in pattern_len_all
        ]

    # A caller-supplied pool without an explicit size is used in full; previously
    # this combination crashed when drawing the pattern index.
    pool_size = None
    if patterns is not None:
        pool_size = len(patterns) if pattern_pool_size is None else pattern_pool_size

    for i in range(sample_size):
        if patterns is None:
            pattern_len = int(np.random.randint(low=rep_l, high=rep_h))
            pattern_sample = torch.multinomial(p, pattern_len, replacement=True)
        else:
            pattern_sample = patterns[np.random.randint(low=0, high=pool_size)]
            pattern_len = len(pattern_sample)

        # r is the number of free slots once all copies are accounted for
        r = max_seq_len - pattern_len * num_repeat
        if r < num_repeat:
            raise ValueError(
                f"cannot place {num_repeat} copies of a length-{pattern_len} pattern "
                f"in a sequence of length {max_seq_len}"
            )
        # num_repeat distinct sorted offsets in [0, r); copy j starts at
        # offsets[j] + j * pattern_len, so consecutive copies never overlap
        offsets = torch.sort(
            torch.multinomial(torch.ones(r) / r, num_repeat, replacement=False)
        )[0].tolist()
        for j, offset in enumerate(offsets):
            start_pos = offset + j * pattern_len
            data[i, start_pos : (start_pos + pattern_len)] = pattern_sample
            starts[i, j] = start_pos
        lens[i] = pattern_len

    return (data, lens, starts, patterns) if return_lens else data


def gen_mod_add_data(
    vocab,
    max_seq_len,
    sample_size,
    distr=None,
    pattern_pool_size=None,
    patterns=None,
    rep_l=11,
    rep_h=20,
    num_repeat=2,
    return_lens=False,
):
    """
    Generate random sequences that start with num_repeat back-to-back copies of a pattern.
    The copies occupy positions 0 .. num_repeat * pattern_len - 1; the remaining tokens
    are random background. Background and patterns are token indices drawn from distr;
    vocab is only used for its size.
    NOTE: every copy is currently a plain repeat of the pattern; the modular
    cumulative-sum transform suggested by the function name is not applied.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: positive integer that specifies the number of tokens in a sequence
        sample_size: the number of input sequences
        distr: 1d tensor of sampling weights over token indices; uniform if None
        pattern_pool_size: if given (and patterns is None), a pool of this many patterns
            is sampled once and every sequence uses a pattern from that pool
        patterns: optional list of 1d index tensors used as the pattern pool; when
            pattern_pool_size is also given only the first pattern_pool_size entries are used
        rep_l: smallest sampled pattern length (inclusive)
        rep_h: upper bound on the sampled pattern length (exclusive)
        num_repeat: number of copies of the pattern per sequence
        return_lens: if True, also return pattern lengths, start positions and the pool
    Returns:
        data: 2d torch.Tensor of type torch.long; if return_lens, returns
        (data, lens, starts, patterns), where starts[i, j] is the first index of the
        j-th copy in sequence i and patterns is the pool (None if none was used)
    """
    vocab_size = vocab.size(0)
    p = torch.ones(vocab_size) / vocab_size if distr is None else distr

    data = torch.multinomial(p, sample_size * max_seq_len, replacement=True).view(
        sample_size, max_seq_len
    )
    lens = np.zeros(sample_size, dtype=int)
    starts = np.zeros((sample_size, num_repeat), dtype=int)

    if pattern_pool_size is not None and patterns is None:
        # sample the pool from p, with lengths drawn uniformly from [rep_l, rep_h)
        pattern_len_all = np.random.randint(
            low=rep_l, high=rep_h, size=pattern_pool_size
        )
        patterns = [
            torch.multinomial(p, int(n), replacement=True) for n in pattern_len_all
        ]

    # A caller-supplied pool without an explicit size is used in full; previously
    # this combination crashed when drawing the pattern index.
    pool_size = None
    if patterns is not None:
        pool_size = len(patterns) if pattern_pool_size is None else pattern_pool_size

    for i in range(sample_size):
        if patterns is None:
            pattern_len = int(np.random.randint(low=rep_l, high=rep_h))
            pattern_sample = torch.multinomial(p, pattern_len, replacement=True)
        else:
            pattern_sample = patterns[np.random.randint(low=0, high=pool_size)]
            pattern_len = len(pattern_sample)

        # Pattern locations are deterministic, so the random gaps are not used.
        # The draw is kept only so that seeded runs consume the torch RNG exactly
        # as before and keep producing the same data.
        r = max_seq_len - pattern_len * num_repeat
        torch.multinomial(torch.ones(r) / r, num_repeat, replacement=False)

        for j in range(num_repeat):
            start_pos = j * pattern_len
            data[i, start_pos : (start_pos + pattern_len)] = pattern_sample
            starts[i, j] = start_pos
        lens[i] = pattern_len

    return (data, lens, starts, patterns) if return_lens else data


def gen_simple_Aa_data(
    vocab,
    max_seq_len,
    sample_size,
    pattern=None,
    pattern_sample_len=None,
    rep_l=11,
    rep_h=20,
    return_lens=False,
):
    """
    Generate simple repetitions of certain short-ranged patterns. Each token has two
    versions ("lower case" = first half of the vocabulary, "upper case" = second half);
    the version is chosen at random once per repetition of the pattern.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: positive integer that specifies the maximum number of tokens in a sequence
        sample_size: the number of input sequences
        pattern: 1-d array/tensor of indices into the first half of the vocabulary;
            if None, a random pattern is sampled for every sequence
        pattern_sample_len: the length of sampled patterns, only used when pattern is None;
            if None, the length is drawn uniformly from [rep_l, rep_h) for every sequence
        rep_l: smallest sampled pattern length (inclusive)
        rep_h: upper bound on the sampled pattern length (exclusive)
        return_lens: if True, returns repetition length for each seq
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long; if return_lens, returns (data, lens)
    """
    vocab_size = vocab.size(0)
    vocab_halfsize = vocab_size // 2
    data = torch.zeros(sample_size, max_seq_len, dtype=torch.long)
    lens = np.zeros(sample_size, dtype=int)

    for i in range(sample_size):
        if pattern is None:
            if pattern_sample_len is None:
                pattern_len = int(np.random.randint(low=rep_l, high=rep_h))
            else:
                # honour the requested length (previously pattern_len stayed unassigned)
                pattern_len = int(pattern_sample_len)
            base_pattern = torch.multinomial(
                torch.ones(vocab_halfsize) / vocab_halfsize,
                pattern_len,
                replacement=True,
            )
        else:  # for a given pattern
            base_pattern = torch.as_tensor(pattern)
            pattern_len = len(pattern)

        # smallest number of repetitions covering max_seq_len tokens
        num_repeat = (max_seq_len - 1) // pattern_len + 1
        r = pattern_len * num_repeat - max_seq_len
        # integer staging buffer: token ids never pass through floating point
        tiled = torch.zeros(num_repeat * pattern_len, dtype=torch.long)
        for j in range(num_repeat):
            # one fair coin flip per repetition selects the vocabulary half
            is_upper_case = int(torch.bernoulli(torch.tensor([0.5])).item())
            tiled[(j * pattern_len) : ((j + 1) * pattern_len)] = vocab[
                base_pattern + is_upper_case * vocab_halfsize
            ]
        start_pos = np.random.randint(low=0, high=r + 1)
        data[i, :] = tiled[start_pos : (start_pos + max_seq_len)]
        lens[i] = pattern_len

    return (data, lens) if return_lens else data


def gen_hmm_data(
    vocab,
    max_seq_len,
    sample_size,
    state_sizes,
    transition_mat=None,
    sig=1.5,
    ioi=False,
    special_tokens=None,
    return_states=True,
    return_tokens_rep=False,
):
    """
    Generate simple HMM data: (z_1,z_2,...,z_T) is a latent Markov chain with z_t in {1,2,...,K}
    For each latent state z_t, observation is uniformly sampled from a set O_{z_t}
    The set O_1, O_2, ... O_K are non-overlapping and their cardinality is given by state_sizes
    sum_k state_sizes[k] must be no larger than vocab size.
    When ioi is True, the observations emitted from the first state (O_1) are overwritten:
    a token sequence covering half of those positions (rounded up) is sampled, either
    uniformly from O_1 or from special_tokens, and written twice in a row into them.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: positive integer that specifies the maximum number of tokens in a sequence
        sample_size: the number of input sequences
        state_sizes: a list/numpy array of integers indicating the size of sets for observables
        transition_mat: a K-by-K transition matrix for the latent Markov chain; if none, a tran_mat will be generated
        sig: the parameter used for generating a transition matrix if transition_mat is not provided
        ioi: if True, plant the repetition described above; otherwise uniformly sample tokens from O_1 independently
        special_tokens: a 1d torch array of tokens. If not None, the repeated tokens are drawn from it
        return_states: if True, returns latent states
        return_tokens_rep: if True, returns tokens that are being repeated for each seq (requires ioi=True)
    Returns:
        data: a list containing variables for the HMM,
        including input sequences, 2d torch.Tensor of type torch.long;
        transition_mat (useful when it is being generated in the function);
        if return_states is True, the latent states for each seq;
        if return_tokens_rep is True, for IOI also return the tokens being repeated
    """
    # check input arguments
    vocab_size = vocab.size(0)
    # accept plain lists as documented: the element-wise `> 0` check below needs an array
    state_sizes = np.asarray(state_sizes)
    K = len(state_sizes)
    K_total = np.sum(state_sizes)
    # size_cum[k] is the first token index belonging to state k
    size_cum = np.concatenate(([0], np.cumsum(state_sizes)))
    assert (
        np.all(state_sizes > 0) and K_total <= vocab_size
    ), "Wrong input for state_sizes"
    if transition_mat is not None:
        m1, m2 = transition_mat.shape
        assert (m1 == m2) and (
            m1 == K
        ), "Incorrect input dimension of transition matrix"
        assert torch.all(transition_mat >= 0) and torch.all(
            torch.abs(transition_mat.sum(dim=1) - 1) < 1e-6
        ), "Incorrect input of transition matrix"
    else:
        transition_mat = gen_tran_mat(K, 1, sig=sig)
    _, pi = calc_opt_err(transition_mat)  # get equilibrium distribution pi
    pi = torch.Tensor(pi).float()
    if return_tokens_rep:
        assert ioi, "ioi should be set to True"

    states = torch.zeros(sample_size, max_seq_len, dtype=torch.long)
    data = torch.zeros(sample_size, max_seq_len, dtype=torch.long)
    tokens_rep = torch.zeros(sample_size, 2, dtype=torch.long)
    # initial latent states are drawn from pi
    states[:, 0] = torch.multinomial(pi, sample_size, replacement=True)
    for i in range(sample_size):
        for j in range(max_seq_len):
            if j > 0:
                # advance the latent chain by one step
                states[i, j] = torch.multinomial(
                    transition_mat[states[i, j - 1], :], 1
                )
            state = int(states[i, j])
            size = int(state_sizes[state])
            # emit a token uniformly from the block of the vocabulary owned by this state
            data[i, j] = int(size_cum[state]) + torch.multinomial(
                torch.ones(size) / size, 1
            )
        if ioi:
            loc = states[i] == 0  # identify the state 0 for inserting repetition
            num_loc = int(loc.sum())
            # a sequence that never visits state 0 has nothing to overwrite;
            # sampling zero tokens would make torch.multinomial raise
            if num_loc > 0:
                to_repeat_len = (num_loc + 1) // 2
                if special_tokens is None:
                    tokens = torch.multinomial(
                        torch.ones(state_sizes[0]) / state_sizes[0],
                        to_repeat_len,
                        replacement=True,
                    )  # sample a sequence to repeat
                else:
                    tokens_idx = torch.multinomial(
                        torch.ones(len(special_tokens)) / len(special_tokens),
                        to_repeat_len,
                        replacement=True,
                    )
                    tokens = special_tokens[tokens_idx]
                data[i, loc] = tokens.repeat(2)[:num_loc]
                if return_tokens_rep:  # NOTE: buggy, set this to False
                    # tokens_rep has two columns, so this assignment only works
                    # when `tokens` has length 1 or 2
                    tokens_rep[i] = tokens

    out = [data, transition_mat]
    if return_states:
        out.append(states)
    if return_tokens_rep:
        out.append(tokens_rep)
    return out


def gen_insert_data(
    vocab,
    max_seq_len,
    sample_size,
    background="random",
    background_random_weight=None,
    pattern="aaa",
    insrt_num_tokens=5,
    insrt_sep=1,
):
    """
    Generate input sequences for training/testing by planting a short pattern
    into a background sequence at equally spaced positions around a random centre.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: positive integer that specifies the maximum number of tokens in a sequence
        sample_size: the number of input sequences
        background: 'random' produces random sequence background, otherwise all-zero sequences
        background_random_weight: probability weight for generating random background, default is uniform random
        pattern: 'aaa' samples one token and repeats it, 'random' samples a random short pattern;
                in both cases the pattern is sampled while processing the first sequence and
                then reused for all remaining sequences.
                A 1d torch.Tensor plants the specified pattern; anything else plants nothing
        insrt_num_tokens: the number of tokens being planted, must be an odd int
        insrt_sep: the index difference between consecutive tokens that are planted
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long
        pattern: the torch.Tensor pattern that was planted (the sampled one for 'aaa' / 'random');
                the argument is returned unchanged if nothing was sampled
        pos_arr: 2d float torch.Tensor, the indices of tokens planted in the sequences

    """
    assert type(insrt_num_tokens) == int, "insrt_num_tokens must be an odd integer"
    assert insrt_num_tokens % 2 == 1, "insrt_num_tokens must be an odd integer"
    vocab_size = vocab.size(0)
    half_span = insrt_num_tokens // 2
    if background_random_weight is None:  # uniform random noise in background
        background_random_weight = torch.ones(vocab_size) / vocab_size

    if background == "random":
        per_sequence_weights = background_random_weight.repeat(sample_size, 1)
        data = torch.multinomial(per_sequence_weights, max_seq_len, replacement=True)
    else:
        data = torch.zeros(sample_size, max_seq_len).type(torch.LongTensor)

    pos_arr = torch.zeros(sample_size, insrt_num_tokens)
    # offsets of the planted tokens relative to the centre; identical for every sequence
    relative_offsets = torch.arange(-half_span, half_span + 1) * insrt_sep
    margin = half_span * insrt_sep

    for i in range(sample_size):
        # the centre is chosen so that all planted positions stay inside the sequence
        center = torch.randint(margin, max_seq_len - margin, size=(1,))
        insrt_pos = relative_offsets + center
        pos_arr[i, :] = insrt_pos

        # A string request is replaced by the sampled tensor, so sampling happens
        # only for the first sequence; afterwards the tensor branch below plants it.
        if pattern == "aaa":
            single_token = torch.multinomial(torch.ones(vocab_size), 1)
            pattern = single_token.repeat(insrt_num_tokens)
        elif pattern == "random":
            pattern = torch.multinomial(
                torch.ones(vocab_size), insrt_num_tokens, replacement=True
            )

        if torch.is_tensor(pattern):
            data[i, insrt_pos] = vocab[pattern]
        # any other value: keep the background only

    return data, pattern, pos_arr


def gen_markov_data(
    vocab, max_seq_len, sample_size, transition_mat, init_state_dist=None
):
    """
    Generate input sequences for training/testing from a first-order Markov chain.
    Every token is sampled from the row of the transition matrix selected by the
    previous token. The sequences contain state indices; vocab is only used for its size.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: positive integer that specifies the maximum number of tokens in a sequence
        sample_size: the number of input sequences
        transition_mat: transition matrix for the Markov chain, a square 2d torch.Tensor whose
                size equals the vocabulary size; rows must be non-negative and sum to one
        init_state_dist: the initial state distribution, 1d Tensor that sums to one;
                if None, the first token is drawn uniformly at random
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long
    """
    vocab_size = vocab.size(0)
    num_rows, num_cols = transition_mat.shape
    assert (num_rows == num_cols) and (
        num_rows == vocab_size
    ), "Incorrect input dimension of transition matrix"
    entries_nonnegative = torch.all(transition_mat >= 0)
    rows_sum_to_one = torch.all(torch.abs(transition_mat.sum(dim=1) - 1) < 1e-6)
    assert (
        entries_nonnegative and rows_sum_to_one
    ), "Incorrect input of transition matrix"
    if init_state_dist is not None:
        assert (
            torch.abs(init_state_dist.sum() - 1) < 1e-6
        ), "Incorrect input of initial state distribution: not summing to one"

    data = torch.zeros(sample_size, max_seq_len).type(torch.LongTensor)
    if init_state_dist is not None:
        # first token follows the supplied initial distribution
        data[:, 0] = torch.multinomial(init_state_dist, sample_size, replacement=True)
    else:
        # first token is uniform over the vocabulary
        data[:, 0] = torch.randint(0, vocab_size, size=(sample_size,))

    for seq in range(sample_size):
        for t in range(1, max_seq_len):
            previous_state = data[seq, t - 1]
            data[seq, t] = torch.multinomial(transition_mat[previous_state, :], 1)

    return data


def gen_2nd_markov_data(
    vocab, max_seq_len, sample_size, transition_mat, init_state_dist=None
):
    """
    Generate input sequences for training/testing based on a second-order markov chain.
    The first two tokens of every sequence are the initial state; every later token is drawn
    from `transition_mat[a, b, :]`, where (a, b) are the two preceding tokens.
    All sequences are advanced together, one position per batched multinomial draw, so the
    order in which random numbers are consumed differs from a sequence-by-sequence sampler
    while the distribution of the output is the same.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: integer (at least 2) that specifies the number of tokens in a sequence
        sample_size: the number of input sequences
        transition_mat: transition tensor for the Markov chain, a 3d torch.Tensor K * K * K where the last
                                    dimension is the output probability; must be a valid transition tensor
        init_state_dist: the initial state distribution over the K * K token pairs, 1d Tensor that sums to
                                    one; the pair (a, b) has index a * K + b
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long with shape (sample_size, max_seq_len)
    """
    vocab_size = vocab.size(0)
    m1, m2, m3 = transition_mat.shape
    assert (
        (m1 == m2) and (m1 == m3) and (m1 == vocab_size)
    ), "Incorrect input dimension of transition matrix"
    assert torch.all(transition_mat >= 0) and torch.all(
        torch.abs(transition_mat.sum(dim=2) - 1) < 1e-6
    ), "Incorrect input of transition matrix"
    if init_state_dist is not None:
        assert (
            torch.abs(init_state_dist.sum() - 1) < 1e-6
        ), "Incorrect input of initial state distribution: not summing to one"

    data = torch.zeros(sample_size, max_seq_len, dtype=torch.long)
    if init_state_dist is None:  # uniform random initial states at position 0 and 1
        data[:, 0] = torch.randint(0, vocab_size, size=(sample_size,))
        data[:, 1] = torch.randint(0, vocab_size, size=(sample_size,))
    else:  # decode the sampled pair index a * K + b into its two tokens
        states_full = torch.multinomial(init_state_dist, sample_size, replacement=True)
        data[:, 0] = states_full // vocab_size
        data[:, 1] = states_full % vocab_size

    for j in range(max_seq_len - 2):
        # indexing with two (sample_size,) tensors yields one probability row per sequence
        next_probs = transition_mat[data[:, j], data[:, j + 1]]
        data[:, j + 2] = torch.multinomial(next_probs, 1).squeeze(1)

    return data


def gen_3rd_markov_data(
    vocab, max_seq_len, sample_size, transition_mat, init_state_dist=None
):
    """
    Generate input sequences for training/testing based on a third-order markov chain.
    The first three tokens of every sequence are the initial state; every later token is drawn
    from `transition_mat[a, b, c, :]`, where (a, b, c) are the three preceding tokens.
    All sequences are advanced together, one position per batched multinomial draw, so the
    order in which random numbers are consumed differs from a sequence-by-sequence sampler
    while the distribution of the output is the same.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: integer (at least 3) that specifies the number of tokens in a sequence
        sample_size: the number of input sequences
        transition_mat: transition tensor for the Markov chain, a 4d torch.Tensor K * K * K * K where the last
                                    dimension is the output probability; must be a valid transition tensor
        init_state_dist: the initial state distribution over the K * K * K token triples, 1d Tensor that sums
                                    to one; the triple (a, b, c) has index a * K^2 + b * K + c
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long with shape (sample_size, max_seq_len)
    """
    vocab_size = vocab.size(0)
    m1, m2, m3, m4 = transition_mat.shape
    assert (
        (m1 == m2) and (m1 == m3) and (m1 == m4) and (m1 == vocab_size)
    ), "Incorrect input dimension of transition matrix"
    assert torch.all(transition_mat >= 0) and torch.all(
        torch.abs(transition_mat.sum(dim=3) - 1) < 1e-6
    ), "Incorrect input of transition matrix"
    if init_state_dist is not None:
        assert (
            torch.abs(init_state_dist.sum() - 1) < 1e-6
        ), "Incorrect input of initial state distribution: not summing to one"

    data = torch.zeros(sample_size, max_seq_len, dtype=torch.long)
    if init_state_dist is None:  # uniform random initial states at position 0, 1 and 2
        data[:, 0] = torch.randint(0, vocab_size, size=(sample_size,))
        data[:, 1] = torch.randint(0, vocab_size, size=(sample_size,))
        data[:, 2] = torch.randint(0, vocab_size, size=(sample_size,))
    else:  # decode the sampled triple index into its three tokens
        states_full = torch.multinomial(init_state_dist, sample_size, replacement=True)
        s0, s12 = states_full // (vocab_size**2), states_full % (vocab_size**2)
        data[:, 0] = s0
        data[:, 1] = s12 // vocab_size
        data[:, 2] = s12 % vocab_size

    for j in range(max_seq_len - 3):
        # indexing with three (sample_size,) tensors yields one probability row per sequence
        next_probs = transition_mat[data[:, j], data[:, j + 1], data[:, j + 2]]
        data[:, j + 3] = torch.multinomial(next_probs, 1).squeeze(1)

    return data


def gen_mixed1st_markov_data(
    vocab,
    max_seq_len,
    sample_size,
    transition_mat,
    transition_mat2,
    init_state_dist=None,
):
    """
    Generate input sequences for training/testing based on two first-order markov chains.
    For the token at position j + 1, a short-range candidate is drawn from `transition_mat`
    conditioned on the token at position j, and (once j >= 9) a long-range candidate is drawn
    from `transition_mat2` conditioned on the token at position j - 9. The short-range candidate
    is kept with probability 0.3, otherwise the long-range one is used. While j < 9 there is
    no long-range context and the short-range candidate is always used.
    All sequences are advanced together, one position at a time, so the order in which random
    numbers are consumed differs from a sequence-by-sequence sampler while the distribution of
    the output is the same.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: positive integer that specifies the number of tokens in a sequence
        sample_size: the number of input sequences
        transition_mat: transition matrix for the Markov chain, a 2d torch.Tensor K * K where the second
                                    dimension is the output probability; must be a valid transition matrix
        transition_mat2: similar to transition_mat, used to model long-range dependence
        init_state_dist: the initial state distribution, 1d Tensor that sums to one
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long with shape (sample_size, max_seq_len)
    """
    vocab_size = vocab.size(0)
    m1, m2 = transition_mat.shape
    assert (m1 == m2) and (
        m1 == vocab_size
    ), "Incorrect input dimension of transition matrix"
    assert torch.all(transition_mat >= 0) and torch.all(
        torch.abs(transition_mat.sum(dim=1) - 1) < 1e-6
    ), "Incorrect input of transition matrix"
    m1, m2 = transition_mat2.shape
    assert (m1 == m2) and (
        m1 == vocab_size
    ), "Incorrect input dimension of transition matrix"
    assert torch.all(transition_mat2 >= 0) and torch.all(
        torch.abs(transition_mat2.sum(dim=1) - 1) < 1e-6
    ), "Incorrect input of transition matrix"
    if init_state_dist is not None:
        assert (
            torch.abs(init_state_dist.sum() - 1) < 1e-6
        ), "Incorrect input of initial state distribution: not summing to one"

    short_range_prob = 0.3  # probability of keeping the short-range candidate
    long_range_offset = 9  # the long-range chain conditions on position j - 9

    data = torch.zeros(sample_size, max_seq_len, dtype=torch.long)
    if init_state_dist is None:  # uniform random initial state at position 0
        data[:, 0] = torch.randint(0, vocab_size, size=(sample_size,))
    else:  # use the initial state distribution if provided
        data[:, 0] = torch.multinomial(init_state_dist, sample_size, replacement=True)

    for j in range(max_seq_len - 1):
        token = torch.multinomial(transition_mat[data[:, j]], 1).squeeze(1)
        if j >= long_range_offset:
            token2 = torch.multinomial(
                transition_mat2[data[:, j - long_range_offset]], 1
            ).squeeze(1)
            keep_short = torch.bernoulli(
                torch.full((sample_size,), short_range_prob)
            ).bool()
            # select per sequence instead of blending the two integer tokens as floats
            data[:, j + 1] = torch.where(keep_short, token, token2)
        else:
            data[:, j + 1] = token

    return data


def gen_mixed2nd_markov_data(
    vocab,
    max_seq_len,
    sample_size,
    transition_mat,
    transition_mat2,
    init_state_dist=None,
):
    """
    Generate input sequences for training/testing based on two second-order markov chains.
    For the token at position j + 2, a short-range candidate is drawn from `transition_mat`
    conditioned on the tokens at positions (j, j + 1), and (once j >= 9) a long-range candidate
    is drawn from `transition_mat2` conditioned on the tokens at positions (j - 9, j - 8). Each
    candidate is used with probability 1/2. While j < 9 there is no long-range context and the
    short-range candidate is always used.
    All sequences are advanced together, one position at a time, so the order in which random
    numbers are consumed differs from a sequence-by-sequence sampler while the distribution of
    the output is the same.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: integer (at least 2) that specifies the number of tokens in a sequence
        sample_size: the number of input sequences
        transition_mat: transition tensor for the Markov chain, a 3d torch.Tensor K * K * K where the last
                                    dimension is the output probability; must be a valid transition tensor
        transition_mat2: similar to transition_mat, used to model long-range dependence
        init_state_dist: the initial state distribution over the K * K token pairs, 1d Tensor that sums to
                                    one; the pair (a, b) has index a * K + b
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long with shape (sample_size, max_seq_len)
    """
    vocab_size = vocab.size(0)
    m1, m2, m3 = transition_mat.shape
    assert (
        (m1 == m2) and (m1 == m3) and (m1 == vocab_size)
    ), "Incorrect input dimension of transition matrix"
    assert torch.all(transition_mat >= 0) and torch.all(
        torch.abs(transition_mat.sum(dim=2) - 1) < 1e-6
    ), "Incorrect input of transition matrix"
    m1, m2, m3 = transition_mat2.shape
    assert (
        (m1 == m2) and (m1 == m3) and (m1 == vocab_size)
    ), "Incorrect input dimension of transition matrix"
    assert torch.all(transition_mat2 >= 0) and torch.all(
        torch.abs(transition_mat2.sum(dim=2) - 1) < 1e-6
    ), "Incorrect input of transition matrix"
    if init_state_dist is not None:
        assert (
            torch.abs(init_state_dist.sum() - 1) < 1e-6
        ), "Incorrect input of initial state distribution: not summing to one"

    short_range_prob = 1 / 2  # probability of keeping the short-range candidate
    long_range_offset = 9  # the long-range chain conditions on positions (j - 9, j - 8)

    data = torch.zeros(sample_size, max_seq_len, dtype=torch.long)
    if init_state_dist is None:  # uniform random initial states at position 0 and 1
        data[:, 0] = torch.randint(0, vocab_size, size=(sample_size,))
        data[:, 1] = torch.randint(0, vocab_size, size=(sample_size,))
    else:  # decode the sampled pair index a * K + b into its two tokens
        states_full = torch.multinomial(init_state_dist, sample_size, replacement=True)
        data[:, 0] = states_full // vocab_size
        data[:, 1] = states_full % vocab_size

    for j in range(max_seq_len - 2):
        token = torch.multinomial(
            transition_mat[data[:, j], data[:, j + 1]], 1
        ).squeeze(1)
        if j >= long_range_offset:
            far = j - long_range_offset
            token2 = torch.multinomial(
                transition_mat2[data[:, far], data[:, far + 1]], 1
            ).squeeze(1)
            keep_short = torch.bernoulli(
                torch.full((sample_size,), short_range_prob)
            ).bool()
            # select per sequence instead of blending the two integer tokens as floats
            data[:, j + 2] = torch.where(keep_short, token, token2)
        else:
            data[:, j + 2] = token

    return data


def gen_higher_markov_data(
    vocab, max_seq_len, sample_size, transition_mat, init_state_dist=None
):
    """
    Generate input sequences for training/testing based on a higher-order markov chain,
    which is constructed from a first-order transition probability matrix.
    The first five tokens follow the plain first-order chain. From then on, the next-token
    distribution is the average of the `transition_mat` rows of the five most recent tokens,
    i.e. transition_mat.T applied to the empirical token distribution of that window.
    All sequences are advanced together, one position per batched multinomial draw, so the
    order in which random numbers are consumed differs from a sequence-by-sequence sampler
    while the distribution of the output is the same.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: positive integer that specifies the number of tokens in a sequence
        sample_size: the number of input sequences
        transition_mat: transition matrix for the Markov chain, a 2d torch.Tensor K * K where the second
                                    dimension is the output probability; must be a valid transition matrix
        init_state_dist: the initial state distribution, 1d Tensor of weights over the K tokens
                                    (not validated here; passed to torch.multinomial as is)
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long with shape (sample_size, max_seq_len)
    """
    vocab_size = vocab.size(0)
    m1, m2 = transition_mat.shape
    assert (m1 == m2) and (
        m1 == vocab_size
    ), "Incorrect input dimension of transition matrix"
    assert torch.all(transition_mat >= 0) and torch.all(
        torch.abs(transition_mat.sum(dim=1) - 1) < 1e-6
    ), "Incorrect input of transition matrix"

    window = 5  # number of most recent tokens the higher-order step conditions on

    data = torch.zeros(sample_size, max_seq_len, dtype=torch.long)
    if init_state_dist is None:  # uniform random initial state at position 0
        data[:, 0] = torch.randint(0, vocab_size, size=(sample_size,))
    else:
        data[:, 0] = torch.multinomial(init_state_dist, sample_size, replacement=True)

    for j in range(max_seq_len - 1):
        # until a full window exists, condition on the latest token only (first-order step)
        start = j - (window - 1) if j >= window - 1 else j
        # (sample_size, width, K) rows averaged over the window -> (sample_size, K)
        probs = transition_mat[data[:, start : j + 1]].mean(dim=1)
        data[:, j + 1] = torch.multinomial(probs, 1).squeeze(1)

    return data


def gen_binary_random_pattern(vocab, max_seq_len, sample_size, probs=1 / 2):
    """
    Generate sequences made of two 4-token patterns that are mixed at random.

    Each sequence is a concatenation of 4-token blocks. Every block is independently
    the pattern [0, 2, 3, 5] with probability `probs` and the pattern [1, 2, 3, 4]
    otherwise. When `max_seq_len` is not a multiple of 4, the last block is truncated.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary; only its length is used and
               it must have at least 6 entries because the patterns use tokens 0..5
        max_seq_len: the number of tokens in a sequence
        sample_size: the number of input sequences
        probs: probability that a block is the first pattern
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long with shape (sample_size, max_seq_len)
    """
    vocab_size = vocab.size(0)
    assert vocab_size >= 6, "Need vocabulary size no smaller than 6"

    pattern_len = 4
    # round up so that a trailing partial block is generated and then cut to length
    num_blocks = -(-max_seq_len // pattern_len)

    use_first = torch.bernoulli(probs * torch.ones(sample_size, num_blocks)).bool()
    # expand each per-block choice to the 4 token positions of that block
    use_first = use_first.repeat_interleave(pattern_len, dim=1)
    seq1 = torch.tensor([0, 2, 3, 5], dtype=torch.long).repeat(sample_size, num_blocks)
    seq2 = torch.tensor([1, 2, 3, 4], dtype=torch.long).repeat(sample_size, num_blocks)
    data = torch.where(use_first, seq1, seq2)[:, :max_seq_len].contiguous()

    return data


def gen_mixed_higher_markov_data(
    vocab,
    max_seq_len,
    sample_size,
    transition_mats=None,
    max_order=2,
    order_freq=None,
    sig_param=None,
    delimiter_freq=None,
):
    """
    Generate delimiter-separated sequences in which each segment follows a Markov chain
    whose order is drawn at random.

    The last vocabulary index (vocab_size - 1) is the delimiter token; the chains run
    over the remaining vocab_size - 1 tokens. For every sequence a Markov order in
    1..max_order and a delimiter index are drawn; tokens are then sampled from the
    transition tensor of that order until `counter` exceeds the delimiter index, at which
    point a delimiter token is written and both the order and the delimiter index are redrawn.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary (only its length is used)
        max_seq_len: the number of tokens in a sequence
        sample_size: the number of input sequences
        transition_mats: optional list of max_order tensors; entry m - 1 is indexed by m context
                tokens and yields next-token weights. When None, entry m - 1 is generated with
                m + 1 dimensions of size vocab_size - 1 as exp(sig_param[m - 1] * randn),
                normalized along the last dimension
        max_order: largest Markov order that can be drawn
        order_freq: 1d Tensor of sampling weights over the orders 1..max_order (uniform when None)
        sig_param: per-order scale multiplying the Gaussian noise of the generated tensors
                (1..max_order when None)
        delimiter_freq: 1d Tensor of sampling weights over delimiter indices; when None all
                weight is split equally between indices 6 and 7
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long
        transition_mats: the list of transition tensors that was used
    """
    vocab_size = vocab.size(0)
    delimiter_index_high = (
        8  # length of the default delimiter_freq vector, so sampled indices are below 8
    )
    delimiter_index_low = 5  # default delimiter_freq puts zero weight on indices 0..5
    if order_freq is None:
        order_freq = torch.ones(max_order) / max_order
    if sig_param is None:
        sig_param = torch.arange(max_order) + 1
    if delimiter_freq is None:
        freq = torch.ones(delimiter_index_high)
        freq[: (delimiter_index_low + 1)] = 0
        delimiter_freq = freq / freq.sum()

    if transition_mats is None:  # generate transition matrix/tensor if not provided
        # placeholders only; every entry is replaced in the loop below
        transition_mats = [
            torch.zeros(tuple([vocab_size - 1] * mc_order))
            for mc_order in range(1, max_order + 1)
        ]
        for mc_order in range(1, max_order + 1):
            mat = torch.exp(
                sig_param[mc_order - 1]
                * torch.randn(tuple([vocab_size - 1] * (mc_order + 1)))
            )
            # normalize the last dimension so that it is a distribution over the next token
            transition_mats[mc_order - 1] = mat / mat.sum(dim=-1, keepdim=True)

    data = torch.zeros(sample_size, max_seq_len).long()
    # uniform random initial states (delimiter excluded) in the first max_order positions
    data[:, :max_order] = torch.randint(
        0, vocab_size - 1, size=(sample_size, max_order)
    )
    # sample next-token sequentially
    for i in range(sample_size):
        # both draws are 1-element tensors; they are used below as a comparison bound,
        # a list index and a range bound
        delimiter_index = torch.multinomial(delimiter_freq, 1)
        mc_order = torch.multinomial(order_freq, 1) + 1
        mat = transition_mats[mc_order - 1]
        counter = 0  # number of tokens sampled since the start or the last delimiter
        # NOTE(review): the loop starts at j = 0, so positions 1..max_order-1 of the random
        # initial state are re-sampled, and while j + 1 - mc_order < 0 the range below
        # contains negative indices, which read from the end of the row — confirm intended.
        for j in range(max_seq_len - 1):
            if counter <= delimiter_index:
                counter += 1
                states = data[
                    i, range(j + 1 - mc_order, j + 1)
                ].numpy()  # use the previous mc_order states to sample next token
                if (
                    vocab_size - 1 in states
                ):  # if delimiter is in the states, look past to get one more token and remove delimiter
                    states = data[i, range(j - mc_order, j + 1)].numpy()
                    states = states[states != vocab_size - 1]
                # states is used as a multi-index into mat, leaving the next-token weights
                data[i, j + 1] = torch.multinomial(mat[tuple(states)], 1)
            else:  # reach the index for delimiter, resample index and mc order, and add a delimiter token
                delimiter_index = torch.multinomial(delimiter_freq, 1)
                mc_order = torch.multinomial(order_freq, 1) + 1
                mat = transition_mats[mc_order - 1]
                counter = 0
                data[i, j + 1] = torch.tensor(
                    [vocab_size - 1]
                ).long()  # add a delimiter token (index is vocab_size-1)
    return data, transition_mats


def gen_mixed_delimited_markov_data(
    vocab,
    max_seq_len,
    sample_size,
    components=3,
    transition_mats=None,
    max_order=2,
    component_freq=None,
    sig_param=None,
    delimiter_freq=None,
):
    """
    Generate delimiter-separated sequences in which each segment follows one of several
    Markov chains of order `max_order`.

    The last vocabulary index (vocab_size - 1) is the delimiter token; the chains run over
    the remaining vocab_size - 1 tokens. For every sequence a mixture component and a
    delimiter index are drawn. Tokens are then sampled from that component's transition
    tensor, conditioned on the previous `max_order` tokens, until the number of positions
    since the last delimiter exceeds the delimiter index; a delimiter token is then written
    and both the component and the delimiter index are redrawn. When the context contains a
    delimiter, the next token is drawn uniformly from the non-delimiter tokens instead.
    The first `max_order` positions keep their uniformly drawn initial states.
    Args:
        vocab: 1d torch.Tensor containing entire vocabulary (only its length is used)
        max_seq_len: the number of tokens in a sequence (at least max_order)
        sample_size: the number of input sequences
        components: number of mixture components
        transition_mats: optional list of `components` tensors with max_order + 1 dimensions of
                size vocab_size - 1 whose last dimension holds next-token weights. When None,
                component k is generated as exp(sig_param[k] * randn), normalized along the
                last dimension
        max_order: order of every Markov chain
        component_freq: 1d Tensor of sampling weights over components (uniform when None)
        sig_param: per-component scale multiplying the Gaussian noise of the generated tensors
                (1..components when None)
        delimiter_freq: 1d Tensor of sampling weights over delimiter indices; when None all
                weight is split equally between indices 6 and 7
    Returns:
        data: input sequences, 2d torch.Tensor of type torch.long
        transition_mats: the list of transition tensors that was used
    """
    vocab_size = vocab.size(0)
    delimiter_token = vocab_size - 1
    delimiter_index_high = 8  # length of the default delimiter_freq vector
    delimiter_index_low = 5  # default delimiter_freq puts zero weight on indices 0..5
    if component_freq is None:
        component_freq = torch.ones(components) / components
    if sig_param is None:
        sig_param = torch.arange(components) + 1
    if delimiter_freq is None:
        freq = torch.ones(delimiter_index_high)
        freq[: (delimiter_index_low + 1)] = 0
        delimiter_freq = freq / freq.sum()

    if transition_mats is None:  # generate transition tensors if not provided
        transition_mats = []
        for k in range(components):
            mat = torch.exp(
                sig_param[k] * torch.randn(tuple([vocab_size - 1] * (max_order + 1)))
            )
            transition_mats.append(mat / mat.sum(dim=-1, keepdim=True))

    data = torch.zeros(sample_size, max_seq_len).long()
    # uniform random initial states (delimiter excluded) in the first max_order positions
    data[:, :max_order] = torch.randint(
        0, vocab_size - 1, size=(sample_size, max_order)
    )
    # sample next-token sequentially
    for i in range(sample_size):
        delimiter_index = torch.multinomial(delimiter_freq, 1).item()
        mat = transition_mats[torch.multinomial(component_freq, 1).item()]
        counter = 0  # number of positions visited since the start or the last delimiter
        for j in range(max_seq_len - 1):
            if counter <= delimiter_index:
                counter += 1
                if j + 1 < max_order:
                    # position j + 1 belongs to the initial state and has no full context yet
                    continue
                states = data[i, j + 1 - max_order : j + 1].tolist()
                if delimiter_token in states:
                    # the context straddles a delimiter: just do random sampling
                    data[i, j + 1] = torch.randint(0, vocab_size - 1, size=(1,))
                else:
                    # regular step: sample from the active component given the context
                    data[i, j + 1] = torch.multinomial(mat[tuple(states)], 1)
            else:  # reached the delimiter index: resample index and component, add a delimiter
                delimiter_index = torch.multinomial(delimiter_freq, 1).item()
                mat = transition_mats[torch.multinomial(component_freq, 1).item()]
                counter = 0
                data[i, j + 1] = delimiter_token
    return data, transition_mats


# Training Configuration
Set up training hyperparameters and optimization settings, including the SAM optimizer.

In [38]:
def get_mask(src, lens, starts=None, ignore_segment=0, ignore_burning=0):
    """
    Build the evaluation mask that goes with a batch of sequences.

    Args:
        src: batch of sequences; only its shape (and, without `lens`, its dtype/device) is used
        lens: per-sequence lengths forwarded to the mask helpers, or None
        starts: per-sequence start positions, or None
        ignore_segment: forwarded unchanged to the mask helper
        ignore_burning: forwarded unchanged to the mask helper
    Returns:
        A tensor of ones shaped like `src` when `lens` is None; otherwise the output of
        `mask_get_along_axis` (no `starts`) or `mask_get_given_starts` (with `starts`)
        wrapped in a torch.Tensor.
    """
    if lens is None:
        # nothing to mask by: every position counts
        return torch.ones_like(src)

    helper_options = {
        "ignore_segment": ignore_segment,
        "ignore_burning": ignore_burning,
    }
    if starts is None:
        raw_mask = mask_get_along_axis(src.shape, lens, **helper_options)
    else:
        raw_mask = mask_get_given_starts(src.shape, lens, starts, **helper_options)
    return torch.Tensor(raw_mask)


def get_loss(model, criterion, src):
    """
    Next-token loss of `model` on the batch `src`.

    The model output at position t is scored against the token at position t + 1, so the
    last output and the first token are dropped before both are flattened over batch and
    position and handed to `criterion`.
    """
    logits = model(src)
    num_classes = logits.size(-1)
    flat_logits = logits[:, :-1].contiguous().view(-1, num_classes)
    flat_targets = src[:, 1:].contiguous().view(-1)
    return criterion(flat_logits, flat_targets)

@torch.no_grad()
def loss_err(model, criterion, src, mask):
    """
    Evaluate next-token loss and masked prediction error on the batch `src`.

    The model is switched to eval mode and is left in that mode.
    Returns:
        loss: `criterion` applied to the flattened outputs (last position dropped) and
              the flattened targets (first token dropped)
        err: one minus the masked number of correct argmax predictions divided by the sum
             of the whole mask (the numerator uses the mask without its last column)
    """
    model.eval()
    logits = model(src)
    num_classes = logits.size(-1)
    targets = src[:, 1:]
    loss = criterion(
        logits[:, :-1].contiguous().view(-1, num_classes),
        targets.contiguous().view(-1),
    )

    is_correct = logits.argmax(dim=2)[:, :-1] == targets
    num_correct = torch.sum(is_correct.cpu() * mask[:, :-1], dtype=torch.float)
    err = 1 - num_correct / torch.sum(mask)
    return loss, err


In [39]:
def gen_simulated_data(
    distr,
    vocab,
    max_seq_len,
    sample_size,
    regime,
    pool_size,
    patterns,
    rep_l,
    rep_h,
    device,
):
    """
    Generate one batch of simulated sequences for the requested data regime.

    Args:
        distr: token distribution forwarded to the generator (unused for "simple repetition")
        vocab: 1d torch.Tensor containing entire vocabulary
        max_seq_len: the number of tokens in a sequence
        sample_size: the number of input sequences
        regime: one of "simple repetition", "varied repetition", "modular addition"
        pool_size: forwarded as `pattern_pool_size` (unused for "simple repetition")
        patterns: pattern pool forwarded to the generator (unused for "simple repetition")
        rep_l, rep_h: forwarded unchanged to the generator
        device: device the generated sequences are moved to
    Returns:
        (src, lens, starts, patterns): `src` lives on `device`; `starts` and `patterns` are
        None for "simple repetition".
    Raises:
        ValueError: if `regime` is not one of the supported names.
    """
    if regime == "simple repetition":
        src, lens = gen_simple_data(
            vocab,
            max_seq_len,
            sample_size,
            return_lens=True,
            rep_l=rep_l,
            rep_h=rep_h,
        )
        return src.to(device), lens, None, None

    # the two remaining regimes share one call signature and differ only in the generator
    if regime == "varied repetition":
        generator = gen_repetition_data
    elif regime == "modular addition":
        generator = gen_mod_add_data
    else:
        # fail here instead of returning None and breaking the caller's tuple unpacking
        raise ValueError(f"regime {regime} is not supported!")

    src, lens, starts, patterns = generator(
        vocab,
        max_seq_len,
        sample_size,
        distr=distr,
        pattern_pool_size=pool_size,
        patterns=patterns,
        return_lens=True,
        rep_l=rep_l,
        rep_h=rep_h,
    )
    return src.to(device), lens, starts, patterns

In [40]:
def make_distr(config):
    """
    Build the token sampling distribution selected by `config.distr`.

    Supported values:
        "two-level" / "two-level-3": the first 4 tokens get probability 1/8 each and the
            remaining config.vocab_size - 4 tokens share the other half equally
            (both names currently produce the same distribution)
        "zipf": probability proportional to 1 / (i + 2.7) for i = 1..config.vocab_size
        "unif": no explicit distribution; None is returned
    Returns:
        A 1d torch.Tensor of probabilities, or None for "unif".
    Raises:
        ValueError: for any other value of `config.distr`.
    """
    kind = config.distr
    if kind == "unif":
        return None

    if kind in ("two-level", "two-level-3"):  # "two-level-3" is NOT USED for now, may change later
        num_rare = config.vocab_size - 4
        frequent = np.array([1 / 8] * 4)
        rare = np.array([1 / (2 * num_rare)] * num_rare)
        weights = np.concatenate((frequent, rare))
    elif kind == "zipf":
        # https://en.wikipedia.org/wiki/Zipf%27s_law
        ranks = range(1, config.vocab_size + 1)
        weights = np.array([1 / (rank + 2.7) for rank in ranks])
        weights = weights / np.sum(weights)
    else:
        raise ValueError(f"distr {config.distr} is not supported!")

    return torch.Tensor(weights)

In [41]:
def train_infinite(
    model,
    config,
    optimizer,
    scheduler,
    use_sam = False,
):
    """
    Train `model` on simulated batches and track train / test / OOD loss and error.

    One training batch per epoch is generated up front. Every epoch takes a single optimizer
    step on its batch, then evaluates on that batch, on a fixed test set generated with the
    training settings, and on a fixed OOD test set generated with `distr=None`,
    `pool_size=None` and repetition bounds `config.ood_len_pattern` / `+ 1`. Both test sets
    are saved to `config.out_dir`, checkpoints are written there periodically, and attention
    plots are produced while the OOD error is above 0.05.

    Args:
        model: network to train; called as `model(src)`
        config: run configuration. Fields read here: num_epoch, batch_size, vocab_size,
            max_seq_len, regime, sample_size_test, pool_size, rep_l, rep_h, ood_len_pattern,
            device, ignore_segment, ignore_burning, out_dir, label_smoothing, sharpness_task,
            sharpness_step, d_model, plot_attn_every_epoch, print_output, n_save,
            up_to_first_save (plus whatever `make_distr` reads)
        optimizer: optimizer stepping the model parameters; with `use_sam` it must accept a
            closure in `step`
        scheduler: learning-rate scheduler, stepped once per epoch
        use_sam: when True, `optimizer.step` is called with a closure that recomputes the
            loss and its gradients
    Returns:
        model: the trained model
        err_arr: numpy array of shape (num_epoch, 6) with columns
            [train loss, train err, test loss, test err, OOD loss, OOD err]
        err_arr_json: the same metrics as a list of per-epoch dicts
    """
    num_epoch = config.num_epoch
    batch_size = config.batch_size
    vocab = torch.arange(config.vocab_size).type(torch.LongTensor)
    p = make_distr(config)

    src_test, lens_test, starts_test, patterns = gen_simulated_data(
        distr=p,
        vocab=vocab,
        max_seq_len=config.max_seq_len,
        regime=config.regime,
        sample_size=config.sample_size_test,
        pool_size=config.pool_size,
        patterns=None,
        rep_l=config.rep_l,
        rep_h=config.rep_h,
        device=config.device,
    )

    src_test_ood, lens_test_ood, starts_test_ood, _ = gen_simulated_data(
        distr=None,
        vocab=vocab,
        max_seq_len=config.max_seq_len,
        regime=config.regime,
        sample_size=config.sample_size_test,
        pool_size=None,
        patterns=None,
        rep_l=config.ood_len_pattern,
        rep_h=config.ood_len_pattern + 1,
        device=config.device,
    )

    M_test = get_mask(
        src_test,
        lens_test,
        starts_test,
        ignore_segment=config.ignore_segment,
        ignore_burning=config.ignore_burning,
    )
    M_test_ood = get_mask(
        src_test_ood,
        lens_test_ood,
        starts_test_ood,
        ignore_segment=config.ignore_segment,
        ignore_burning=config.ignore_burning,
    )

    torch.save(
        [src_test, lens_test, starts_test], os.path.join(config.out_dir, "test.pth")
    )
    torch.save(
        [src_test_ood, lens_test_ood, starts_test_ood],
        os.path.join(config.out_dir, "test_ood.pth"),
    )

    err_arr = np.zeros((num_epoch, 6))
    sharpness_arr = np.zeros((num_epoch,))
    trial_sharpness_arr = np.zeros((num_epoch, 2000))
    diff_by_blk_summary = dict()

    err_arr_json = []
    criterion = (
        nn.CrossEntropyLoss(label_smoothing=0.1)
        if config.label_smoothing
        else nn.CrossEntropyLoss()
    )

    # Checkpoint schedule. The period is clamped to at least 1 so that n_save > num_epoch
    # no longer causes a modulo by zero; the optional early saves happen at powers of two
    # below the first regular checkpoint.
    save_every = max(1, config.num_epoch // config.n_save)
    early_save_epochs = (
        {2**k for k in range(int(np.log2(save_every)))}
        if config.up_to_first_save
        else set()
    )

    # Pre-generate one batch per epoch; the test patterns are passed on to the training
    # generator, and the whole list is also handed to the sharpness utilities as `dataset`.
    train_dataset = []
    for epoch in range(num_epoch):
        src, lens_train, starts_train, _ = gen_simulated_data(
            distr=p,
            vocab=vocab,
            max_seq_len=config.max_seq_len,
            regime=config.regime,
            sample_size=batch_size,
            pool_size=config.pool_size,
            patterns=patterns,
            rep_l=config.rep_l,
            rep_h=config.rep_h,
            device=config.device,
        )
        M = get_mask(
            src,
            lens_train,
            starts_train,
            ignore_segment=config.ignore_segment,
            ignore_burning=config.ignore_burning,
        )
        train_dataset.append((src, lens_train, starts_train, _, M))

    # torch.save(train_dataset, "train_dataset.pt")

    for epoch in tqdm(range(num_epoch)):
        model.train()

        optimizer.zero_grad()

        src, lens_train, starts_train, _, M = train_dataset[epoch]

        '''
        src, lens_train, starts_train, _ = gen_simulated_data(
            distr=p,
            vocab=vocab,
            max_seq_len=config.max_seq_len,
            regime=config.regime,
            sample_size=batch_size,
            pool_size=config.pool_size,
            patterns=patterns,
            rep_l=config.rep_l,
            rep_h=config.rep_h,
            device=config.device,
        )
        M = get_mask(
            src,
            lens_train,
            starts_train,
            ignore_segment=config.ignore_segment,
            ignore_burning=config.ignore_burning,
        )
        '''
        loss = get_loss(model, criterion, src)
        loss.backward()
        
        if use_sam:
            def closure():
                optimizer.zero_grad()
                loss = get_loss(model, criterion, src)
                loss.backward()
                return loss
            # Keep `loss` bound to the pre-step training loss: the step's return value may
            # be None, and `loss.item()` is used in the progress print below.
            optimizer.step(closure)
        else:
            optimizer.step()
        

        with torch.no_grad():
            model.eval()  # useful if dropout or batchnorm etc is turned on
            loss_train, train_err = loss_err(model, criterion, src, M)
            loss_test, test_err = loss_err(model, criterion, src_test, M_test)
            loss_test_ood, test_err_ood = loss_err(
                model, criterion, src_test_ood, M_test_ood
            )
            if False: # compute full Hessian
                hessian_train = get_hessian(model, criterion, src=src, dataset=train_dataset)
                sharpness = torch.trace(hessian_train)
                print(sharpness)


        if False: # compute block-diagonal Hessian
            if (epoch % 100 == 0 and 1000 <= epoch <= 2000) or epoch in [0,700]:
                blkdiag_hessian_train = get_blkdiag_hessian(model, criterion, src=src, dataset=train_dataset)
                avg_sharpness = sum([torch.trace(h) for h in blkdiag_hessian_train])
                blk_spectrums = [torch.linalg.eigh(h)[0] for h in blkdiag_hessian_train]
                # spectrum = torch.concat(blk_spectrums)
                
                # plot spectrum
                parameter_names = [name for name, _ in model.named_parameters()]
                plot_blk_spectrum(
                   blk_spectrums, 
                   parameter_names, 
                   fig_name=f"spectrum_epoch_{epoch}", 
                   save_dir=config.out_dir
                )

                import matplotlib.pyplot as plt
                spectrum = torch.concat(blk_spectrums)
                plt.figure()
                plt.hist(spectrum, 100)
                plt.yscale('log')
                plt.savefig(os.path.join(config.out_dir, f"spectrum_hist_epoch_{epoch}"))

        if False: # directly compute Hessian trace
            if epoch % 1000 == 0: # (epoch % 100 == 0 and 1000 <= epoch <= 2000) or epoch in [0,700]:
                sharpness_trace = get_trace_hessian(model, criterion, src=src, dataset=train_dataset)
                avg_sharpness = sum(sharpness_trace) / len(sharpness_trace)
                
                # scale the sharpness by model weight
                # avg_sharpness *= sum([torch.norm(p).item()**2 for p in model.parameters()])

        if False:
            if epoch % config.sharpness_step == 0:
                avg_sharpness, diff_by_blk = get_robustness_blk(model, criterion, src=src, dataset=train_dataset, num_perturb=100, r_perturb=1e-3, data_sample_size=20, config=config)

        if False:
            if epoch % config.sharpness_step == 0:
                diff = get_robustness(model, criterion, src=src, dataset=train_dataset, num_perturb=100, r_perturb=1e-3, data_sample_size=20, config=config)
                avg_sharpness = sum(diff) / len(diff)

        if config.sharpness_task == "outer-product-Hessian":
            if epoch % config.sharpness_step == 0 and epoch > 9997:
                H_out = get_outer_product_hess(model, criterion, src=src, dataset=train_dataset)
                H = get_blkdiag_hessian(model, criterion, src=src, dataset=train_dataset)
                torch.save((H_out, H), f"out/out-hess-{epoch}.pt")

        if config.sharpness_task == "outer-product-Hessian-decompose":
            if epoch % config.sharpness_step == 0 and epoch > 5997:
                H_out = get_outer_product_hess_decompose(model, criterion, src=src, dataset=train_dataset)
                H = get_blkdiag_hessian(model, criterion, src=src, dataset=train_dataset)
                torch.save((H_out, H), f"out/out-hess-decompose-{epoch}.pt")

        if config.sharpness_task == "outer-product-Hessian-random-alignment":
            # NOTE: this task overwrites the model weights with random values, so it is a
            # diagnostic run, not a training run; it stops the process at epoch 10.
            if 2 < epoch < 10:
                # random model
                state_dict = model.state_dict().copy()
                for name in model.state_dict():
                    state_dict[name] = torch.randn_like(model.state_dict()[name]) / math.sqrt(config.d_model)
                model.load_state_dict(state_dict)
                H_out = get_outer_product_hess_decompose(model, criterion, src=src, dataset=train_dataset)
                H = get_blkdiag_hessian(model, criterion, src=src, dataset=train_dataset)
                torch.save((H_out, H), f"out/out-hess-random-{epoch}.pt")

                # aligned model
                # NOTE(review): `rot` is assigned in the W_q branch and reused in the W_k
                # branch, so this relies on 'h.1.mha.W_q.weight' being visited before
                # 'h.1.mha.W_k.weight' in the state dict — confirm against the model.
                state_dict = model.state_dict().copy()
                for name in model.state_dict():
                    if name not in ['h.1.mha.W_q.weight', 'h.1.mha.W_k.weight']:
                        state_dict[name] = torch.randn_like(model.state_dict()[name]) / math.sqrt(config.d_model)
                    else:
                        if name == 'h.1.mha.W_q.weight':
                            rot, _ = torch.linalg.qr(torch.randn_like(model.state_dict()[name]))
                            state_dict[name] = torch.linalg.inv(state_dict['h.0.mha.W_o.weight'] @ state_dict['h.0.mha.W_v.weight'].T) @ rot
                        else:
                            state_dict[name] = rot
                model.load_state_dict(state_dict)
                H_out = get_outer_product_hess_decompose(model, criterion, src=src, dataset=train_dataset)
                H = get_blkdiag_hessian(model, criterion, src=src, dataset=train_dataset)
                torch.save((H_out, H), f"out/out-hess-align-{epoch}.pt")
            elif epoch >= 10:
                exit()


        '''
        if len(diff_by_blk_summary) == 0:
            for k in diff_by_blk.keys():
                diff_by_blk_summary[k] = [diff_by_blk[k].item()]
        else:
            for k in diff_by_blk.keys():
                diff_by_blk_summary[k].append(diff_by_blk[k].item())
        '''

        #sharpness_arr[epoch] = avg_sharpness
        #trial_sharpness_arr[epoch] = np.array([d.item() for d in diff])

        err_arr[epoch, :] = [
            loss_train.item(),
            train_err.item(),
            loss_test.item(),
            test_err.item(),
            loss_test_ood.item(),
            test_err_ood.item(),
        ]

        err_arr_json += [
            {
                "epoch": epoch,
                "loss_train": loss_train.item(),
                "err_train": train_err.item(),
                "loss_test": loss_test.item(),
                "err_test": test_err.item(),
                "loss_ood": loss_test_ood.item(),
                "err_ood": test_err_ood.item(),
            }
        ]

        scheduler.step()

        # plot (and optionally print) only while the OOD error is still above 5%
        if epoch % config.plot_attn_every_epoch == 0 and err_arr[epoch, 5] > 0.05:
            plots_maker(
                model,
                config,
                [src, src_test, src_test_ood],
                epoch=epoch,
                lens=[lens_train, lens_test, lens_test_ood],
                starts=[starts_train, starts_test, starts_test_ood],
                save_dir=os.path.join(config.out_dir, "figures"),
            )

            if config.print_output:
                print(
                    f"----> Epoch: {epoch+1:>5}, Train Loss: {loss.item():.3f}, Test Error: {err_arr[epoch,3]:.3f}, OOD Error: {err_arr[epoch,5]:.3f}"
                )

        if (1 + epoch) % save_every == 0 or (1 + epoch) in early_save_epochs:
            out_path = os.path.join(config.out_dir, f"ckpt_{epoch + 1}.pt")
            torch.save(model.state_dict(), out_path)

    # per-position error plot on the last training batch and both test sets
    lens = [lens_train, lens_test, lens_test_ood]
    _ = plot_err_over_pos(
        model,
        [src, src_test, src_test_ood],
        config.vocab_size,
        "err_over_pos",
        lens=lens,
        starts=[starts_train, starts_test, starts_test_ood],
        src_labels=["train", "test", "ood"],
        save_dir=config.out_dir,
    )

    # np.save("out/trial_diff.npy", trial_sharpness_arr)

    return model, err_arr, err_arr_json



In [42]:
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    A full update consists of two forward/backward passes:
      1. ``first_step`` moves every parameter along its (optionally
         adaptively rescaled) gradient by a step of norm ``rho``.
      2. ``second_step`` restores the original parameters and lets the base
         optimizer apply the gradients computed at the perturbed point.

    Args:
        params: parameters (or parameter groups) to optimize
        base_optimizer: optimizer class (not an instance), e.g. torch.optim.SGD
        rho: size of the perturbation applied in ``first_step``
        adaptive: if True, scale the perturbation element-wise by the parameter magnitude
        **kwargs: forwarded to ``base_optimizer`` (lr, momentum, ...)
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        sam_defaults = {"rho": rho, "adaptive": adaptive}
        sam_defaults.update(kwargs)
        super().__init__(params, sam_defaults)

        # The base optimizer is built on top of our param groups; afterwards
        # both optimizers share the same list of group dicts.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb the weights from "w" to "w + e(w)", remembering "w"."""
        total_norm = self._grad_norm()
        for group in self.param_groups:
            step_scale = group["rho"] / (total_norm + 1e-12)

            for param in group["params"]:
                if param.grad is None:
                    continue
                # keep a copy of the unperturbed weights for second_step
                self.state[param]["old_p"] = param.data.clone()
                if group["adaptive"]:
                    weighting = torch.pow(param, 2)
                else:
                    weighting = 1.0
                perturbation = weighting * param.grad * step_scale.to(param)
                param.add_(perturbation)  # climb to the local maximum "w + e(w)"

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore "w" and apply the base optimizer with the current gradients."""
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    # get back to "w" from "w + e(w)"
                    param.data = self.state[param]["old_p"]

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run both SAM steps; ``closure`` must do a full forward-backward pass."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        # gradients are disabled by the decorator, so re-enable them for the closure
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the global 2-norm of all (optionally adaptively scaled) gradients."""
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device
        per_param_norms = []
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                factor = torch.abs(param) if group["adaptive"] else 1.0
                per_param_norms.append(
                    (factor * param.grad).norm(p=2).to(shared_device)
                )
        return torch.norm(torch.stack(per_param_norms), p=2)

    def load_state_dict(self, state_dict):
        """Load optimizer state and re-link the base optimizer's param groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

In [43]:
#####################################################
################# Simple utility functions ###################
#####################################################


def create_folder(dir):
    """Make sure the directory ``dir`` exists, creating missing parents too."""
    if os.path.isdir(dir):
        return
    os.makedirs(dir, exist_ok=True)


def fix_random_seed(seed, reproduce=False):
    """
    Seed the Python, NumPy and PyTorch random number generators.
    Args:
        seed: integer seed shared by all generators
        reproduce: if True, additionally switch cuDNN to deterministic mode
            (benchmark off) and seed the CUDA generators
    Returns:
        rng: the torch.Generator returned by torch.manual_seed
    """
    # cudnn.enabled = True
    # cudnn.benchmark = True

    if reproduce:
        cudnn.deterministic = True
        cudnn.benchmark = False
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        ## NOTE: uncomment for CUDA >= 10.2
        # os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        ## NOTE: uncomment for pytorch >= 1.8
        # torch.use_deterministic_algorithms(True)

    # os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    return torch.manual_seed(seed)


#####################################################


def gen_tran_mat(vocab_size, order, sig=1, sparsity=None):
    """
    Sample a random transition tensor whose last axis is a probability distribution.
    Args:
        vocab_size: size of every axis of the returned tensor
        order: the returned tensor has order + 1 axes
        sig: scale applied to the Gaussian samples before exponentiation
        sparsity: if not None, entries below the global (1 - sparsity) quantile are
            set to zero and the distributions are renormalized
    Returns:
        mat: tensor of shape (vocab_size,) * (order + 1) summing to one along the last axis
    """
    mat = torch.exp(sig * torch.randn(tuple([vocab_size] * (order + 1))))
    mat = mat / mat.sum(dim=-1, keepdim=True)
    if sparsity is not None:
        cutoff = torch.quantile(mat.flatten(), 1 - sparsity)
        keep = mat >= cutoff
        # The cutoff is global, so a whole row can fall below it. Zeroing such a
        # row and renormalizing would divide by zero and produce NaNs, so the
        # largest entry of every row is always retained. Rows that already have
        # a surviving entry are unaffected by this extra condition.
        keep = keep | (mat == mat.max(dim=-1, keepdim=True).values)
        mat = torch.where(keep, mat, torch.zeros_like(mat))
        mat = mat / mat.sum(dim=-1, keepdim=True)
    return mat


def calc_opt_err(mat):
    """
    Compute the best achievable next-state error of a Markov chain, i.e. the error of
    always predicting the most likely next state, averaged under the equilibrium distribution.
    Args:
        mat: square transition probability matrix (rows sum to one); validity is asserted
    Returns:
        err_opt: scalar, the optimal error under equilibrium distribution
        pi: 1d array, equilibrium distribution
    """
    n_rows, n_cols = mat.shape
    assert n_rows == n_cols, "Incorrect input dimension of transition matrix"
    is_stochastic = bool(torch.all(mat >= 0)) and bool(
        torch.all(torch.abs(mat.sum(dim=1) - 1) < 1e-6)
    )
    assert is_stochastic, "Incorrect input of transition matrix"

    # The equilibrium distribution is the left eigenvector for the top eigenvalue.
    eigvals, eigvecs = np.linalg.eig(mat.numpy().T)
    top = np.argsort(eigvals)[-1]
    pi = np.real(eigvecs[:, top])
    pi = pi / np.sum(pi)  # normalize so that it sums to one

    # Per-state error of the greedy predictor: one minus the largest transition probability.
    greedy_err = 1 - mat.max(dim=1).values.numpy()
    err_opt = np.dot(pi, greedy_err)

    return err_opt, pi


def get_mat_full(mat, order=2):
    """
    Convert the transition tensor of an order-`order` Markov chain into the transition
    matrix of the equivalent first-order chain over tuples of `order` consecutive tokens.
    1) if order=2, the tensor form K*K*K becomes the matrix form (K^2) * (K^2)
    2) if order=3, the tensor form K*K*K*K becomes the matrix form (K^3) * (K^3)
    The same construction is applied for any other positive integer order.
    State (k_1, ..., k_order) has row index k_1 * K^(order-1) + ... + k_order and moves to
    state (k_2, ..., k_order, j) with probability mat[k_1, ..., k_order, j].
    Args:
        mat: the transition probability tensor with order + 1 axes of equal size K;
            its validity (non-negative, last axis sums to one) is asserted
        order: order of the Markov chain, a positive integer
    Returns:
        mat_full: float32 transition probability matrix of shape (K^order, K^order)
    Raises:
        ValueError: if order is not a positive integer or mat does not have order + 1 axes
        AssertionError: if the axes differ in size or mat is not a valid probability tensor
    """
    # Previously an unsupported order only issued a warning and then crashed with
    # an UnboundLocalError on the return statement; fail early and explicitly instead.
    if int(order) != order or order < 1:
        raise ValueError(
            f"The order argument receives an incorrect input: {order!r} (expected a positive integer)."
        )
    order = int(order)
    if mat.dim() != order + 1:
        raise ValueError(
            f"Expected a tensor with {order + 1} dimensions for order={order}, got {mat.dim()}."
        )

    vocab_size = mat.shape[0]
    assert all(
        size == vocab_size for size in mat.shape
    ), "Incorrect input dimension of transition matrix"
    assert torch.all(mat >= 0) and torch.all(
        torch.abs(mat.sum(dim=-1) - 1) < 1e-6
    ), "Incorrect input of transition matrix"

    num_states = vocab_size**order
    mat_full = torch.zeros(num_states, num_states).float()

    # Row index of every current state (k_1, ..., k_order) in row-major order.
    state_idx = torch.arange(num_states).long()
    # Dropping k_1 and shifting left gives the base column index of (k_2, ..., k_order, 0).
    shifted = (state_idx % (vocab_size ** (order - 1))) * vocab_size
    next_idx = shifted.unsqueeze(1) + torch.arange(vocab_size).long().unsqueeze(0)
    # One vectorized assignment replaces the nested Python loops over all states.
    mat_full[state_idx.unsqueeze(1), next_idx] = mat.reshape(
        num_states, vocab_size
    ).to(mat_full)

    return mat_full


def mask_get_along_axis(shape, indices):
    """
    Build a 0/1 mask where row i is zero before column indices[i] and one from there on.
    Args:
        shape: (number of rows, number of columns) of the mask
        indices: for each row, the number of leading columns set to zero
    Returns:
        mask: float numpy array of the given shape
    """
    assert shape[0] == len(
        indices
    ), "length of indices should match number of rows in shape"
    if len(indices) == 0:
        # np.array([]) would be 1-D; keep the result 2-D so column slicing still works
        return np.zeros((0, shape[1]))
    return np.array(
        [
            np.concatenate((np.zeros(indices[i]), np.ones(shape[1] - indices[i])))
            for i in range(len(indices))
        ]
    )


def mask_get_given_starts(shape, lens, starts, ignore_segment=0, ignore_burning=0):
    """
    Build a 0/1 integer mask marking, for every row, the segments that begin at the
    given start positions.
    Args:
        shape: (number of rows, number of columns) of the mask
        lens: per-row segment length
        starts: 2-d array, starts[i, j] is where the j-th segment of row i begins
        ignore_segment: number of leading segments per row to leave unmarked
        ignore_burning: number of leading positions of each segment to leave unmarked
    Returns:
        mask: integer numpy array of the given shape
    """
    num_rows, seq_len = shape[0], shape[1]
    num_lens = len(lens)
    num_start_rows, num_segments = starts.shape
    assert num_rows == num_lens and num_rows == num_start_rows, "Wrong input shapes"

    mask = np.zeros((num_rows, seq_len), dtype=int)
    segment_ids = range(ignore_segment, num_segments)
    for row in range(num_rows):
        for seg in segment_ids:
            begin = starts[row, seg]
            mask[row, (begin + ignore_burning) : (begin + lens[row])] = 1
    return mask


#####################################################
##################### Making plots ######################
#####################################################


def plot_err_curve(
    err_arr,
    setting_params=None,
    fig_name=None,
    save_dir=None,
    opt_err=None,
    plot_ood=False,
    plot_train=True,
    log_training_time=False,
):
    """
    Plot loss curves (left panel) and error curves (right panel) from err_arr, optionally
    saving the figure in the specified folder
    Args:
        err_arr: a numpy array of size num_epoch-by-6; columns are read as
            train loss, train err, test loss, test err, ood test loss, ood test err
        setting_params: a dictionary containing setting parameters; only "num_epoch" is
            read here (defaults to the number of rows of err_arr when None)
        fig_name: file name of the saved figure; if None the figure is shown instead
        save_dir: folder for the saved figure; if None, "Figs" is used (and created)
        opt_err: an optional value of the optimal achievable error, drawn as a dashed line
        plot_ood: if true, plots the loss/error curves for ood test data
        plot_train: if true, plots the loss/error curves for train data
        log_training_time: if true, uses a log scale for the epoch axis
    """
    num_epoch = (
        setting_params["num_epoch"] if setting_params is not None else err_arr.shape[0]
    )
    if fig_name is not None:
        if save_dir is None:
            os.makedirs("Figs", exist_ok=True)
            save_path = os.path.join("Figs", fig_name)
        else:
            save_path = os.path.join(save_dir, fig_name)

    epochs = np.arange(num_epoch, dtype=int)
    fig, axs = plt.subplots(1, 2, figsize=(15, 8))
    loss_ax, err_ax = axs[0], axs[1]

    # ---- left panel: losses ----
    if plot_train:
        loss_ax.plot(epochs, err_arr[:, 0], linewidth=2, label="train loss")
    loss_ax.plot(epochs, err_arr[:, 2], linewidth=2, label="test loss")
    if plot_ood:
        loss_ax.plot(epochs, err_arr[:, 4], linewidth=2, label="ood test loss")
    loss_ax.set_yscale("log")
    # the curve labels were set but never displayed before
    loss_ax.legend()
    # The title reports the *test loss* (column 2); it used to read column 1 (train error).
    loss_ax.set_title(
        f"Train/test loss, last/best test epoch {err_arr[-1,2]:.3f}, {np.min(err_arr[:,2]):.3f}",
        weight="bold",
    )
    loss_ax.set_xlabel("Epochs", weight="bold")

    # ---- right panel: errors ----
    if plot_train:
        err_ax.plot(epochs, err_arr[:, 1], linewidth=2, label="train err")
    err_ax.plot(epochs, err_arr[:, 3], linewidth=2, label="test err")
    if plot_ood:
        err_ax.plot(epochs, err_arr[:, 5], linewidth=2, label="ood test err")
    if opt_err is not None:
        err_ax.plot(
            epochs,
            np.repeat(opt_err, num_epoch),
            linestyle="dashed",
            label="optimal err",
        )
    err_ax.legend()
    err_ax.set_title(
        f"Train/test error, last/best test epoch {err_arr[-1,3]:.3f}, {np.min(err_arr[:,3]):.3f}",
        weight="bold",
    )
    err_ax.set_xlabel("Epochs", weight="bold")
    # err_ax.axhline(y=0.5, xmin=0, xmax=num_epoch, linestyle="dashed", c="black")

    if log_training_time:
        loss_ax.set_xscale("log")
        err_ax.set_xscale("log")

    if fig_name is None:
        plt.show()
    else:
        plt.savefig(save_path, bbox_inches="tight")

def plot_blk_spectrum(
    spectrum_list,
    param_names,
    fig_name=None,
    save_dir=None,
):
    """
    Draw one bar chart per spectrum on a grid with 5 columns.
    Args:
        spectrum_list: a sequence of 1d arrays, one spectrum per parameter block
        param_names: titles for the bar charts, indexed like spectrum_list
        fig_name: file name of the saved figure; if None the figure is shown instead
        save_dir: folder for the saved figure; if None, "Figs" is used (and created)
    """
    if fig_name is not None:
        if save_dir is None:
            os.makedirs("Figs", exist_ok=True)
            save_path = os.path.join("Figs", fig_name)
        else:
            save_path = os.path.join(save_dir, fig_name)

    spectrum_list = list(spectrum_list)
    n_cols = 5
    # Keep the original 4x5 layout for up to 20 spectra and add rows beyond that;
    # the fixed grid used to raise an IndexError for more than 20 spectra.
    n_rows = max(4, -(-len(spectrum_list) // n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 2 * n_rows))
    plt.tight_layout()
    for idx, spec in enumerate(spectrum_list):
        row, col = divmod(idx, n_cols)
        axs[row][col].bar(np.arange(len(spec)), spec)
        axs[row][col].set_title(f"{param_names[idx]}")
    if fig_name is None:
        plt.show()
    else:
        plt.savefig(save_path, bbox_inches="tight")

def plot_sharpness_curve(
    sharpness_arr,
    setting_params=None,
    fig_name=None,
    save_dir=None,
):
    """
    Plot the values in sharpness_arr against their index.
    Args:
        sharpness_arr: sequence/array of values to plot
        setting_params: unused; kept so existing callers keep working
        fig_name: file name of the saved figure; if None the figure is shown instead
        save_dir: folder for the saved figure; if None, "Figs" is used (and created)
    """
    # NOTE: the former `num_epoch` lookup was never used by the plot and could
    # only fail (missing "num_epoch" key, input without .shape), so it is gone.
    if fig_name is not None:
        if save_dir is None:
            os.makedirs("Figs", exist_ok=True)
            save_path = os.path.join("Figs", fig_name)
        else:
            save_path = os.path.join(save_dir, fig_name)

    plt.figure()
    plt.plot(sharpness_arr)
    if fig_name is None:
        plt.show()
    else:
        plt.savefig(save_path, bbox_inches="tight")

def plot_err_curve_hmm(
    err_arr,
    setting_params=None,
    fig_name=None,
    save_dir=None,
    opt_err=None,
    plot_ood=False,
):
    """
    A simple function to make plots based on err_arr. Similar to plot_err_curve, but also plots errors based on hmm
    Args:
        err_arr: a numpy array with one row per epoch and at least 15 columns. Columns are read
            in (train, test, ood-test) triples:
            0/2/4 losses, 1/3/5 errors, 6/7/8 IOI panel, 9/10/11 state prediction errors,
            12/13/14 transition matrix prediction errors
        setting_params: a dictionary containing setting parameters; only "num_epoch" is
            read here (defaults to the number of rows of err_arr when None)
        fig_name: file name of the saved figure; if None the figure is shown instead
        save_dir: folder for the saved figure; if None, "Figs" is used (and created)
        opt_err: an optional value of the optimal achievable error, drawn in the error panel
        plot_ood: if true, plots the loss/error curves for ood test data
    """
    num_epoch = (
        setting_params["num_epoch"] if setting_params is not None else err_arr.shape[0]
    )
    if fig_name is not None:
        if save_dir is None:
            os.makedirs("Figs", exist_ok=True)
            save_path = os.path.join("Figs", fig_name)
        else:
            save_path = os.path.join(save_dir, fig_name)

    epochs = np.arange(num_epoch, dtype=int)
    fig, axs = plt.subplots(2, 3, figsize=(3 * 6, 2 * 6))

    # (axes position, (train, test, ood) columns, curve kind used in labels,
    #  title prefix, whether the y-axis is log-scaled). axs[1, 2] stays empty.
    panels = [
        ((0, 0), (0, 2, 4), "loss", "Train/test loss", True),
        ((0, 1), (1, 3, 5), "err", "Train/test error", False),
        ((0, 2), (6, 7, 8), "err", "Train/test loss only for IOI", True),
        ((1, 0), (9, 10, 11), "err", "Train/test state prediction errors", True),
        (
            (1, 1),
            (12, 13, 14),
            "err",
            "Train/test tran matrix prediction errors",
            True,
        ),
    ]
    for loc, (col_train, col_test, col_ood), kind, title, log_y in panels:
        ax = axs[loc]
        ax.plot(epochs, err_arr[:, col_train], linewidth=2, label=f"train {kind}")
        ax.plot(epochs, err_arr[:, col_test], linewidth=2, label=f"test {kind}")
        if plot_ood:
            # ood curves in the error panels used to be mislabelled "ood test loss"
            ax.plot(
                epochs, err_arr[:, col_ood], linewidth=2, label=f"ood test {kind}"
            )
        if loc == (0, 1) and opt_err is not None:
            ax.plot(
                epochs,
                np.repeat(opt_err, num_epoch),
                linestyle="dashed",
                label="optimal err",
            )
        if log_y:
            ax.set_yscale("log")
        # every panel gets a legend (previously only the error panel had one)
        ax.legend()
        # The summary always refers to the panel's own *test* column; the loss
        # panel used to report column 1 (train error) instead of column 2.
        ax.set_title(
            f"{title}, last/best test epoch {err_arr[-1, col_test]:.3f}, {np.min(err_arr[:, col_test]):.3f}",
            weight="bold",
        )
        ax.set_xlabel("Epochs", weight="bold")

    if fig_name is None:
        plt.show()
    else:
        plt.savefig(save_path, bbox_inches="tight")


def plot_attention(
    model,
    tokens,
    fig_name,
    norm=True,
    pos=None,
    savefig_dir="Figs",
    use_mask=True,
    num_heads=1,
    layer=0,
):
    """
    This function saves, for every head of one attention layer, a heatmap of the pre-softmax
    QK values and a heatmap of the attention weights
    Args:
        model: the simpleT model we use for the simulations
        tokens: a sequence of tokens, where each token is any element of type torch.long in the vocabulary
        fig_name: name of figure when saving plots
        norm: if True, apply the layer's ln_1 to the hidden states before computing queries/keys/values
        pos: positional embedding type; for "rotary" or "relative" model.pos_embed is not applied
        savefig_dir: folder in which the figure is saved
        use_mask: currently not used; a causal (lower-triangular) mask is always applied
        num_heads: number of attention heads in the model
        layer: index of the layer in model.h whose attention is plotted
    Returns:
        attn: attentions, numpy array with the batch dimension squeezed out
        QK_vals: the pre-softmax QK values, numpy array with the batch dimension squeezed out

    """
    model.eval()
    # Nothing here is differentiated, so skip building the autograd graph.
    with torch.no_grad():
        seq = (
            model.pos_embed(model.embed(tokens.unsqueeze(0)))
            if pos not in ["rotary", "relative"]
            else model.embed(tokens.unsqueeze(0))
        )
        _, seq_len, d_model = seq.size()
        d_k = d_model // num_heads
        if d_model % num_heads != 0:
            warnings.warn("d_model is not divisible by num_heads!")
        mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).to(model.device)

        # run the layers below the requested one to obtain its input
        h = seq
        for layer0 in range(layer):
            h = model.h[layer0](h, mask=mask)
        h = model.h[layer].ln_1(h) if norm else h
        queries = model.h[layer].mha.W_q(h)
        keys = model.h[layer].mha.W_k(h)
        values = model.h[layer].mha.W_v(h)
        # split the model dimension into heads: (1, num_heads, seq_len, d_k)
        queries = queries.view(1, seq_len, num_heads, d_k).transpose(1, 2)
        keys = keys.view(1, seq_len, num_heads, d_k).transpose(1, 2)
        values = values.view(1, seq_len, num_heads, d_k).transpose(1, 2)

        # second return value is indexed as (attention weights, pre-softmax QK values)
        _, QKval = model.h[layer].mha.scaled_dot_product_attention(
            queries, keys, values, mask=mask
        )
    attn = QKval[0].squeeze(dim=0).numpy(force=True)
    QK_vals = QKval[1].squeeze(dim=0).numpy(force=True)

    ## making plots now
    # Three columns per head; the first column is intentionally left blank
    # (the line-based attention drawing that used to live there is disabled).
    fig, axs = plt.subplots(num_heads, 3, figsize=(3 * 9, num_heads * 9))

    for head in range(num_heads):
        # with a single head, plt.subplots returns a 1-d array of axes
        plot_idx = (head, 1) if num_heads > 1 else 1
        pcm = axs[plot_idx].imshow(QK_vals[head, :, :])
        axs[plot_idx].set_title(
            f"Pre-softmax QK values: head {head}", weight="bold", fontsize=25
        )
        fig.colorbar(pcm, ax=axs[plot_idx], shrink=0.8)

        plot_idx = (head, 2) if num_heads > 1 else 2
        pcm = axs[plot_idx].imshow(attn[head, :, :])
        axs[plot_idx].set_title(
            f"Attention values: head {head}", weight="bold", fontsize=25
        )
        fig.colorbar(pcm, ax=axs[plot_idx], shrink=0.8)
    plt.savefig(os.path.join(savefig_dir, fig_name), bbox_inches="tight")
    plt.close()
    return attn, QK_vals


def plot_incoh_heatmat(save_dir, model, setting_params, remove_firstlast=True):
    """
    Save a 2x2 figure with the Gram matrices (top row) and singular value spectra
    (bottom row) of the token embedding and positional embedding matrices.
    Args:
        save_dir: folder in which the figure is saved
        model: model exposing model.embed.embed.weight and model.pos_embed.pe.weight
        setting_params: dictionary; the keys "train_embed", "add_embed", "use_MLP" and
            "sig" are read to build the file name
        remove_firstlast: if True, drop the first and last row of the positional embedding
    """
    train_embed_str = "trainEmbed_" if setting_params["train_embed"] else ""
    add_embed_str = "addEmbed_" if setting_params["add_embed"] else ""
    MLP_str = "MLP_" if setting_params["use_MLP"] else ""
    sig = setting_params["sig"]
    plt_save_name = MLP_str + train_embed_str + add_embed_str + f"sig_{sig}_incoh"

    fig, ax = plt.subplots(2, 2, figsize=(16, 16))

    # numpy(force=True) detaches and copies to host memory when needed, so this
    # also works for models living on the GPU (plain .detach().numpy() raises there).
    W_e = model.embed.embed.weight.numpy(force=True)
    Gram_e = W_e @ W_e.T
    pcm = ax[0, 0].imshow(Gram_e)
    fig.colorbar(pcm, ax=ax[0, 0], shrink=0.8)
    ax[0, 0].set_title("Gram matrix of static embed matrix", weight="bold")

    W_p = model.pos_embed.pe.weight.numpy(force=True)
    W_p = W_p[1:-1] if remove_firstlast else W_p
    Gram_p = W_p @ W_p.T
    pcm = ax[0, 1].imshow(Gram_p)
    fig.colorbar(pcm, ax=ax[0, 1], shrink=0.8)
    ax[0, 1].set_title("Gram matrix of positional embed matrix", weight="bold")

    # only the singular values are plotted
    _, s, _ = np.linalg.svd(W_e)
    ax[1, 0].plot(s)
    ax[1, 0].set_xlabel("index")
    ax[1, 0].set_yscale("log")
    ax[1, 0].set_title("Spectrum of the static embed matrix", weight="bold")

    _, s, _ = np.linalg.svd(W_p)
    ax[1, 1].plot(s)
    ax[1, 1].set_xlabel("index")
    ax[1, 1].set_yscale("log")
    ax[1, 1].set_title("Spectrum of the positional embed matrix", weight="bold")

    plt.savefig(os.path.join(save_dir, plt_save_name), bbox_inches="tight")


def plot_err_over_pos(
    model,
    src_list,
    vocab_size,
    fig_name,
    criterion=nn.CrossEntropyLoss(reduction="none"),
    lens=None,
    src_labels=None,
    starts=None,
    return_predictions=False,
    ignore_Aa=False,
    save_dir="Figs",
):
    """
    This function makes a plot to show the loss/error of next-token prediction at each
    position, and the error grouped by the current input token
    Args:
        model: the simple TF model we use for the simulations
        src_list: a list of datasets to test the model, each a 2D Tensor
        vocab_size: size of vocabulary
        fig_name: name of figure when saving plots; if None the figure is shown instead
        criterion: per-element loss (no reduction) applied to logits and next tokens
        lens: optional list with one entry per dataset; lens[i] and starts[i] are passed to
            mask_get_given_starts to select the positions that count towards the errors.
            Datasets whose entry is None count every position
        src_labels: optional legend labels, one per dataset
        starts: list with one entry per dataset; required for datasets whose lens entry is not None
        return_predictions: if True, also return per-token errors and predictions
        ignore_Aa: (useful in Aa setting) if True, then case sensitivity is ignored when calculating errors
        save_dir: folder for the saved figure; if None, "Figs" is used (and created)
    Returns:
        loss_list: a list of arrays of length T-1 (max_seq_len-1), average loss at each position
        err_list: a list of arrays of length T-1 (max_seq_len-1), average error at each position
        If return_predictions is True, the tuple is (loss_list, err_list, err2_list, pred_list), where
        err2_list: a list of arrays of length vocab_size, average error for each current token
        pred_list: a list of arrays with the predicted next tokens for each dataset

    """
    if fig_name is not None:
        if save_dir is None:
            os.makedirs("Figs", exist_ok=True)
            save_path = os.path.join("Figs", fig_name)
        else:
            save_path = os.path.join(save_dir, fig_name)
    src_labels = [None] * len(src_list) if src_labels is None else src_labels
    lens = [None] * len(src_list) if lens is None else lens
    vocab_halfsize = vocab_size // 2
    eps = 1e-6
    model.eval()
    loss_list = []
    err_list = []
    err2_list = []
    pred_list = []

    for i, src in enumerate(src_list):
        N, T = src.size()
        if lens[i] is not None:
            M = torch.tensor(mask_get_given_starts(src.shape, lens[i], starts[i]))
        else:
            # No mask requested: every position counts. (An all-zero mask, as
            # used before, made every reported error exactly 1.)
            M = torch.ones(src.shape)
        with torch.no_grad():
            output = model(src)
            loss = (
                criterion(
                    output[:, :-1].contiguous().view(-1, vocab_size),
                    src[:, 1:].contiguous().view(-1),
                )
                .reshape(N, T - 1)
                .mean(dim=0)
            )
            pred = output.argmax(dim=2)[:, :-1]
            if ignore_Aa:
                # a prediction also counts as correct if it is off by half the vocabulary
                tmp = (
                    (pred == src[:, 1:])
                    | (pred == src[:, 1:] + vocab_halfsize)
                    | (pred == src[:, 1:] - vocab_halfsize)
                )
            else:
                tmp = pred == src[:, 1:]
        err = 1 - torch.sum(tmp.cpu() * M[:, :-1], dim=0) / (
            torch.sum(M[:, :-1], dim=0) + eps
        )  # averaged err at each position
        err2 = torch.zeros(vocab_size)
        for j in range(vocab_size):
            err2[j] = 1 - torch.sum((tmp * (src[:, :-1] == j)).cpu() * M[:, :-1]) / (
                torch.sum((src[:, :-1] == j).cpu() * M[:, :-1]) + eps
            )  # averaged err at each token
        loss_list.append(loss.numpy(force=True))
        err_list.append(err.numpy(force=True))
        err2_list.append(err2.numpy(force=True))
        pred_list.append(pred.numpy(force=True))

    fig, axs = plt.subplots(3, 1, figsize=(10, 6 * 3))
    for i, src in enumerate(src_list):
        # Use each dataset's own length for the x-axis instead of the T left
        # over from the last iteration of the loop above.
        positions = np.arange(len(loss_list[i]), dtype=int)
        axs[0].plot(positions, loss_list[i], "-o", label=src_labels[i])
        axs[1].plot(positions, err_list[i], "-o", label=src_labels[i])
        axs[2].plot(
            np.arange(vocab_size, dtype=int), err2_list[i], "-o", label=src_labels[i]
        )
    axs[0].set_xlabel("Position", weight="bold")
    axs[0].set_ylabel("Loss", weight="bold")
    axs[0].set_title(
        "Averaged next-token prediction loss at each position", weight="bold"
    )
    axs[1].set_xlabel("Position", weight="bold")
    axs[1].set_ylabel("Error", weight="bold")
    axs[1].set_title(
        "Averaged next-token prediction error at each position", weight="bold"
    )
    axs[2].set_xlabel("Token", weight="bold")
    axs[2].set_ylabel("Error", weight="bold")
    axs[2].set_title(
        "Averaged next-token prediction error at each token", weight="bold"
    )
    axs[0].legend()
    axs[1].legend()
    axs[2].legend()

    if fig_name is None:
        plt.show()
    else:
        plt.savefig(save_path, bbox_inches="tight")
        plt.close()

    out = (
        (loss_list, err_list, err2_list, pred_list)
        if return_predictions
        else (loss_list, err_list)
    )
    return out


def plot_qk_subspace_matching(model, fig_name, config):
    """
    Plot the normalized singular values of W_qk (layer 1) and W_ov (layer 0) together with
    the top singular value of the overlap between their leading singular subspaces.
    Args:
        model: model exposing model.h[0].mha and model.h[1].mha with W_q/W_k/W_v/W_o weights
        fig_name: path of the saved figure
        config: object whose d_model attribute is used to scale W_qk
    """
    num_svals_plot = 32
    attn_first, attn_second = model.h[0].mha, model.h[1].mha
    W_q = attn_second.W_q.weight.numpy(force=True)
    W_k = attn_second.W_k.weight.numpy(force=True)
    W_v = attn_first.W_v.weight.numpy(force=True)
    W_o = attn_first.W_o.weight.numpy(force=True)
    W_qk = W_q.T @ W_k / np.sqrt(config.d_model)
    W_ov = W_o @ W_v
    U_qk, s_qk, Vt_qk = np.linalg.svd(W_qk)
    U_ov, s_ov, Vt_ov = np.linalg.svd(W_ov)

    # Row 0: right subspace of W_qk against left subspace of W_ov ("inner");
    # row 1: right subspace of W_ov against left subspace of W_qk ("outer").
    s_match = np.zeros((2, num_svals_plot))
    for rank in range(1, num_svals_plot + 1):
        inner_svals = np.linalg.svd(Vt_qk[:rank, :] @ U_ov[:, :rank])[1]
        outer_svals = np.linalg.svd(Vt_ov[:rank, :] @ U_qk[:, :rank])[1]
        s_match[0, rank - 1] = inner_svals[0]
        s_match[1, rank - 1] = outer_svals[0]

    fig, axs = plt.subplots(1, 2, figsize=(6 * 2, 6 * 1))
    for ax, match_vals, title in zip(axs, s_match, ("inner match", "outer match")):
        ax.plot(s_qk[:num_svals_plot] / s_qk[0], "-o", label="qk", linewidth=2)
        ax.plot(s_ov[:num_svals_plot] / s_ov[0], "-o", label="ov", linewidth=2)
        ax.plot(match_vals[:num_svals_plot], "-o", label="match")
        ax.legend()
        ax.set_title(title)

    plt.savefig(fig_name)
    plt.close()


def plot_wqkov(model, fig_name, config):
    """
    Save a heatmap of the product W_qk @ W_ov, with W_qk taken from layer 1 and W_ov from layer 0.
    Args:
        model: model exposing model.h[0].mha and model.h[1].mha with W_q/W_k/W_v/W_o weights
        fig_name: path of the saved figure
        config: object whose d_model attribute is used to scale W_qk
    """
    attn_first, attn_second = model.h[0].mha, model.h[1].mha
    query_w = attn_second.W_q.weight.numpy(force=True)
    key_w = attn_second.W_k.weight.numpy(force=True)
    value_w = attn_first.W_v.weight.numpy(force=True)
    out_w = attn_first.W_o.weight.numpy(force=True)

    qk = query_w.T @ key_w / np.sqrt(config.d_model)
    ov = out_w @ value_w

    sns.heatmap(qk @ ov, square=True)
    plt.savefig(fig_name)
    plt.close()


#####################################################
################# Calculating errors ################
#####################################################


def hmm_calc_err(
    src_list, output_list, states_list, state_sizes, transition_mat, contains_ood=True
):
    """
    Calculate the errors after every epoch during training.

    Args:
        src_list: list of token tensors for train data, test data, ood data;
            each is indexed as [sample, position]
        output_list: list of next-token prediction scores for the same data;
            each is indexed as [sample, position, token] and is passed through
            a softmax here
        states_list: list of hidden Markov states for train data, test data,
            ood data; each is indexed as [sample, position]
        state_sizes: number of tokens belonging to each hidden state; all
            entries must be identical
        transition_mat: true state-transition matrix the estimated one is
            compared against
        contains_ood: not read by this function; if False, the above lists
            simply do not contain the ood entries
    Returns:
        err_ratios: tensor with train/test/ood next-token prediction errors
        err_states_ratios: tensor with errors for predicting hidden states, on
            train/test/ood respectively
        err_probs_ratios: predicted transition matrix vs. true transition
            matrix, measured under L_1 loss, on train/test/ood respectively
    """
    K = len(state_sizes)
    assert all(
        size == state_sizes[0] for size in state_sizes
    ), "Currently only support identical state sizes"
    s = state_sizes[0]
    err_ratios = torch.zeros(len(src_list))
    err_states_ratios = torch.zeros(len(src_list))
    err_probs_ratios = torch.zeros(len(src_list))
    for k, (src, output, states) in enumerate(zip(src_list, output_list, states_list)):
        T = output.size(1)
        pred = output.argmax(dim=2)
        # cleaning; remove examples from counting errors if too few states are 0
        nums_zero_state = torch.sum(states == 0, dim=1)
        id_keep = (
            nums_zero_state > 3
        )  # remove some instances from dataset if too few satisfy states==0 such that ioi is impossible
        sample_size_effective = torch.sum(id_keep)
        src, states, output, pred = (
            src[id_keep],
            states[id_keep],
            output[id_keep],
            pred[id_keep],
        )
        # The batch may have shrunk: every reshape below must use the number of
        # sequences that survived the filter, not the original batch size.
        num_kept = output.size(0)

        # read states and probabilities from pred/output
        pred_states = pred // s
        pred_states = (
            pred_states * (pred_states < K)
        ).long()  # treating special symbols as having state 0
        probs = F.softmax(output, dim=-1)
        # Sum token probabilities within each state: (num_kept, T, K)
        state_probs = probs[:, :, : (s * K)].reshape(num_kept, T, K, s).sum(dim=3)
        # special symbols combined with state 0
        state_probs[:, :, 0] += probs[:, :, (s * K) :].sum(dim=2)

        # Average predicted state distribution, conditioned on the current state
        pred_probs = torch.zeros(K, K)
        for j1 in range(K):
            in_state = states == j1
            count = torch.sum(in_state)
            for j2 in range(K):
                pred_probs[j1, j2] = torch.sum(in_state * state_probs[:, :, j2]) / count

        loc1, loc2 = torch.nonzero(states == 0, as_tuple=True)
        # loc1, loc2 = loc1[loc2!=0], loc2[loc2!=0] # NOTE: ignore this for now
        # The prediction made at position t-1 is compared with the token at t;
        # for t == 0 the index -1 wraps around to the last position.
        total_err = torch.sum(src[loc1, loc2] != pred[loc1, loc2 - 1])
        total_zero_state = len(loc1)
        # One state-0 position per kept sequence is removed from both the error
        # count and the total.
        err_ratios[k] = (total_err - sample_size_effective) / (
            total_zero_state - sample_size_effective
        )
        err_states_ratios[k] = torch.mean(
            (pred_states[:, 2:-1] != states[:, 3:]).float()
        )
        err_probs_ratios[k] = torch.sum(torch.abs(pred_probs - transition_mat))

    return err_ratios, err_states_ratios, err_probs_ratios


######### making various plots ############


def plots_maker(
    model,
    config,
    src_list,
    starts=None,
    epoch=None,
    lens=None,
    save_dir=None,
):
    """
    Making various plots for a model during/after training.

    Figures written under ``save_dir``:
      * error over position (via ``plot_err_over_pos``)
      * Gram matrix of normalised position/token embeddings (skipped when
        ``config.pos`` is "rotary" or "relative")
      * attention maps for the first sequence of each dataset, one per layer
        (via ``plot_attention``, saved under ``save_dir/attn``)
      * heatmaps of W_qk / W_ov (and W_qk @ W_ov for every pair of layers
        ``layer1 < layer2``)
      * subspace matching between W_qk and W_ov of the last layer pair

    Args:
        model: model exposing ``h[i].mha.W_{q,k,v,o}.weight`` (and, for the
            Gram plot, ``pos_embed.pe.weight`` and ``embed.embed.weight``)
        config: object providing ``num_layers``, ``d_model``, ``vocab_size``,
            ``pos``, ``num_heads`` and ``norm``
        src_list: exactly three datasets, in the order train, test, ood
        starts, lens: forwarded unchanged to ``plot_err_over_pos``
        epoch: value embedded in the output file names
        save_dir: directory the figures are written to
    """
    assert len(src_list) == 3, "Only supports includuing OOD data"
    src_labels = ["train", "test", "ood"]
    num_layers = config.num_layers
    d_model = config.d_model

    def _qk_ov(layer_qk, layer_ov):
        # W_qk from the query/key weights of layer_qk, W_ov from the
        # output/value weights of layer_ov.
        mha_qk = model.h[layer_qk].mha
        mha_ov = model.h[layer_ov].mha
        W_q = mha_qk.W_q.weight.numpy(force=True)
        W_k = mha_qk.W_k.weight.numpy(force=True)
        W_v = mha_ov.W_v.weight.numpy(force=True)
        W_o = mha_ov.W_o.weight.numpy(force=True)
        return W_q.T @ W_k / np.sqrt(d_model), W_o @ W_v

    # plot errors at each position
    _ = plot_err_over_pos(
        model,
        list(src_list),
        config.vocab_size,
        f"err_over_pos_epoch_{epoch}",
        lens=lens,
        starts=starts,
        src_labels=list(src_labels),
        save_dir=save_dir,
    )

    # plot Gram matrix
    if config.pos not in ["rotary", "relative"]:
        wpe = F.normalize(model.pos_embed.pe.weight, dim=-1)
        wte = F.normalize(model.embed.embed.weight, dim=-1)
        # force=True detaches and copies to host memory, so this also works
        # when the model lives on the GPU (a bare .numpy() would raise there).
        basis = torch.concat([wpe, wte], dim=0).numpy(force=True)
        Gram = basis @ basis.T

        _, ax = plt.subplots(1, 1, figsize=(12 * 1, 12 * 1))
        sns.heatmap(Gram, ax=ax, vmin=-1, vmax=1, cmap="bwr")
        ax.set_title("Gram matrix: [pos, token]", weight="bold")
        plt.savefig(os.path.join(save_dir, f"Gram_matrix_epoch_{epoch}"))
        plt.close()

    # plot attention (only the first sequence of each dataset is visualised)
    attn_dir = os.path.join(save_dir, "attn")
    create_folder(attn_dir)
    for layer in range(num_layers):
        for k, src0 in enumerate(src_list):
            plot_attention(
                model,
                src0[0, :],
                pos=config.pos,
                layer=layer,
                num_heads=config.num_heads,
                norm=config.norm,
                fig_name=f"_{src_labels[k]}_{layer}_epoch_{epoch}",
                savefig_dir=attn_dir,
            )

    if num_layers == 1:
        # A single layer has no cross-layer composition to show.
        W_qk, W_ov = _qk_ov(0, 0)
        _, axs = plt.subplots(1, 2, figsize=(6 * 2, 6 * 1))
        sns.heatmap(W_qk, ax=axs[0])
        axs[0].set_title("W_qk", weight="bold")
        sns.heatmap(W_ov, ax=axs[1])
        axs[1].set_title("W_ov", weight="bold")
        plt.savefig(os.path.join(save_dir, f"QK_OV_visz_epoch_{epoch}"))
        plt.close()
        return

    # plot weight matrices
    W_qk = W_ov = None
    for layer1 in range(num_layers):
        for layer2 in range(layer1 + 1, num_layers):
            W_qk, W_ov = _qk_ov(layer2, layer1)
            W_qkov = W_qk @ W_ov
            _, axs = plt.subplots(1, 3, figsize=(6 * 3, 6 * 1))
            sns.heatmap(W_qk, ax=axs[0], square=True)
            axs[0].set_title("W_qk", weight="bold")
            sns.heatmap(W_ov, ax=axs[1], square=True)
            axs[1].set_title("W_ov", weight="bold")
            sns.heatmap(W_qkov, ax=axs[2], square=True)
            axs[2].set_title("W_qkov", weight="bold")
            plt.savefig(
                os.path.join(
                    save_dir, f"QK_OV_visz_layer_{layer1}_{layer2}_epoch_{epoch}"
                )
            )
            plt.close()

    if W_qk is None:
        # No layer pair exists, so there is nothing to match.
        return

    # plot subspace matching; W_qk / W_ov are those of the last layer pair
    # visited by the loop above
    num_svals_plot = 32
    U_qk, s_qk, Vt_qk = np.linalg.svd(W_qk)
    U_ov, s_ov, Vt_ov = np.linalg.svd(W_ov)
    s_match = np.zeros((2, num_svals_plot))
    for j in range(num_svals_plot):
        _, s, _ = np.linalg.svd(Vt_qk[: (j + 1), :] @ U_ov[:, : (j + 1)])
        _, s2, _ = np.linalg.svd(Vt_ov[: (j + 1), :] @ U_qk[:, : (j + 1)])
        s_match[0, j] = s[0]
        s_match[1, j] = s2[0]

    _, axs = plt.subplots(1, 2, figsize=(6 * 2, 6 * 1))
    for ax, match_curve, title in zip(axs, s_match, ("inner match", "outer match")):
        ax.plot(s_qk[:num_svals_plot] / s_qk[0], "-o", label="qk", linewidth=2)
        ax.plot(s_ov[:num_svals_plot] / s_ov[0], "-o", label="ov", linewidth=2)
        ax.plot(match_curve[:num_svals_plot], "-o", label="match")
        ax.legend()
        ax.set_title(title)
    plt.savefig(os.path.join(save_dir, f"subspace_matching_{epoch}"))
    plt.close()

# Run configuration and setup
Define the run configuration, the learning-rate scheduler factory, and the function that runs an experiment.

In [44]:
# Configuration class
class Config:
    """
    Attribute-style container for the configuration of a TFModel.

    Every keyword argument passed to the constructor becomes an instance
    attribute of the same name; the model is instantiated from these values.
    """

    def __init__(self, **kwargs):
        # Write straight into the instance namespace, one option at a time.
        namespace = vars(self)
        for option, value in kwargs.items():
            namespace[option] = value

# Report which PyTorch build this notebook is running against
print(f"PyTorch version: {torch.__version__}")


# Default experiment configuration (replacing config_0.yaml).
# Individual entries can be overridden through run_experiment(config_override=...).
config_dict = dict(
    # reproduce
    seed=2026,
    # model
    vocab_size=16,  # 32
    d_model=32,  # 64
    ff_dim=256,
    num_heads=1,
    num_layers=2,
    # TF model variants
    linear_attn=False,
    residual=True,
    mlp=False,
    dropout=0.1,
    norm=True,
    output_norm=False,
    trainable_norm=False,
    pos="rotary",
    rotary_theta=10000,
    # data generation
    max_seq_len=64,
    sample_size=5000,
    sample_size_test=5000,
    regime="varied repetition",
    distr="two-level",
    rep_l=10,
    rep_h=20,
    ood_len_pattern=25,
    pool_size=None,
    sig=2,
    # training
    device="cuda" if torch.cuda.is_available() else "cpu",
    fancy_opt=False,
    use_wd=True,
    schedule="constant",  # read by make_scheduler: "constant" or "cosine"
    fresh_sample=True,  # True -> train_infinite, False -> train_finite
    label_smoothing=False,
    optimizer="sam",  # read by run_experiment: "adamw", "sgd" or "sam"
    lr=0.001,
    wd=0.0005,
    batch_size=50,
    num_epoch=10000,
    # logging
    wandb_log=False,
    plot_attn_every_epoch=100,
    print_output=False,
    n_save=1,  # 500
    up_to_first_save=False,
    # eval
    ignore_segment=1,
    ignore_burning=4,
    # IO
    out_dir="out_sam_0.001",
    # sharpness task
    sharpness_step=1000,
    sharpness_task=None,
)


PyTorch version: 2.7.0+cu126


In [45]:
def make_scheduler(optimizer, config):
    """
    Build the learning-rate scheduler selected by ``config.schedule``.

    Args:
        optimizer: optimizer whose learning rate is scheduled
        config: object providing ``schedule`` ("constant" or "cosine") and,
            for the cosine schedule, ``num_epoch`` (used as the annealing
            period)
    Returns:
        a ``torch.optim.lr_scheduler`` instance wrapping ``optimizer``
    Raises:
        ValueError: if ``config.schedule`` is not a supported schedule name
    """
    if config.schedule == "constant":
        # NOTE(review): ConstantLR is created with its library defaults, which
        # per the PyTorch docs scale the learning rate down for the first few
        # scheduler steps before restoring it — confirm this is intended for a
        # "constant" schedule.
        return torch.optim.lr_scheduler.ConstantLR(optimizer)
    if config.schedule == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, config.num_epoch
        )
    # Fail loudly instead of hitting an unbound local on return.
    raise ValueError(
        f"Unknown schedule {config.schedule!r}; expected 'constant' or 'cosine'"
    )

In [46]:
def run_experiment(config_override=None):
    """
    Main function to run the experiment.

    Builds the configuration, seeds the RNGs, creates the output folders,
    saves the configuration and the initial checkpoint, trains the model,
    then saves the final checkpoint, the error curve figure and the error
    arrays under ``config.out_dir``.

    Args:
        config_override: Dictionary with parameters to override from default
            config. Keys that are not present in the default config are
            reported and ignored.
    Returns:
        (model, err_arr, err_arr_json) as returned by the training routine,
        with the model switched to eval mode
    Raises:
        ValueError: if ``config.optimizer`` is not "adamw", "sgd" or "sam"
    """
    # Create configuration
    config_args = config_dict.copy()
    if config_override:
        for k, v in config_override.items():
            if k not in config_args:
                print(f"Warning: {k} is not supported!")
                # Skip the key: looking it up below would raise KeyError.
                continue
            if v != config_args[k]:
                print(f"{k} is overloaded from {config_args[k]} to {v}")
                config_args[k] = v

    config = Config(**config_args)

    # Set random seed
    fix_random_seed(config.seed, reproduce=True)

    # Create output directories
    create_folder(config.out_dir)
    create_folder(os.path.join(config.out_dir, "figures"))

    # Print and save configuration
    print("Configuration:")
    for k, v in config.__dict__.items():
        print(f"  {k}: {v}")

    with open(os.path.join(config.out_dir, "config.json"), "w") as f:
        json.dump(config.__dict__, f, indent=2)

    # Initialize model
    model = TFModel(config).to(config.device)

    # Save initial model
    out_path = os.path.join(config.out_dir, "ckpt_0.pt")
    torch.save(model.state_dict(), out_path)

    # Setup optimizer
    use_sam = False
    if config.optimizer == "adamw":
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.lr,
            betas=(0.9, 0.98),
            eps=1e-9,
            weight_decay=config.wd,
        )
    elif config.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=config.lr,
            momentum=0.9,
            nesterov=True,
            weight_decay=config.wd,
        )
    elif config.optimizer == "sam":
        use_sam = True
        base_optimizer = torch.optim.SGD
        optimizer = SAM(
            model.parameters(),
            base_optimizer,
            lr=config.lr,
            momentum=0.9,
            nesterov=True,
            weight_decay=config.wd,
        )
    else:
        # Without this branch `optimizer` would be unbound below.
        raise ValueError(
            f"Unknown optimizer {config.optimizer!r}; expected 'adamw', 'sgd' or 'sam'"
        )

    # Setup scheduler
    scheduler = make_scheduler(optimizer, config)

    # Train model
    if config.fresh_sample:
        model, err_arr, err_arr_json = train_infinite(
            model=model,
            config=config,
            optimizer=optimizer,
            scheduler=scheduler,
            use_sam=use_sam,
        )
    else:
        # NOTE(review): use_sam is not forwarded here — confirm train_finite
        # handles a SAM optimizer when config.optimizer == "sam".
        model, err_arr, err_arr_json = train_finite(
            model=model,
            config=config,
            optimizer=optimizer,
            scheduler=scheduler,
        )

    model.eval()

    # Save final model
    out_path = os.path.join(config.out_dir, "ckpt.pt")
    torch.save(model.state_dict(), out_path)

    # Plot results
    plot_err_curve(
        err_arr,
        fig_name="train_test_curves",
        save_dir=config.out_dir,
        plot_ood=True,
        plot_train=not config.fresh_sample,
        log_training_time=config.fresh_sample,
    )

    # Save error arrays
    with open(os.path.join(config.out_dir, "err_arr.json"), "w") as f:
        json.dump(err_arr_json, f, indent=2)

    print(f"Experiment completed. Results saved in {config.out_dir}")

    return model, err_arr, err_arr_json

# Run the experiment

In [47]:
# Train with the defaults in config_dict; pass a dict to run_experiment to
# override individual entries.
model, err_arr, err_arr_json = run_experiment()

Configuration:
  seed: 2026
  vocab_size: 16
  d_model: 32
  ff_dim: 256
  num_heads: 1
  num_layers: 2
  linear_attn: False
  residual: True
  mlp: False
  dropout: 0.1
  norm: True
  output_norm: False
  trainable_norm: False
  pos: rotary
  rotary_theta: 10000
  max_seq_len: 64
  sample_size: 5000
  sample_size_test: 5000
  regime: varied repetition
  distr: two-level
  rep_l: 10
  rep_h: 20
  ood_len_pattern: 25
  pool_size: None
  sig: 2
  device: cuda
  fancy_opt: False
  use_wd: True
  schedule: constant
  fresh_sample: True
  label_smoothing: False
  optimizer: sam
  lr: 0.001
  wd: 0.0005
  batch_size: 50
  num_epoch: 10000
  wandb_log: False
  plot_attn_every_epoch: 100
  print_output: False
  n_save: 1
  up_to_first_save: False
  ignore_segment: 1
  ignore_burning: 4
  out_dir: out_sam_0.001
  sharpness_step: 1000
  sharpness_task: None


KeyboardInterrupt: 

# Plot

In [None]:
# Plot the error curves returned by run_experiment
fig, ax = plt.subplots(figsize=(10, 6))
if isinstance(err_arr, dict):
    # One labelled curve per entry
    for key, values in err_arr.items():
        ax.plot(values, label=key)
else:
    ax.plot(err_arr)
ax.set_xlabel('Epoch')
ax.set_ylabel('Error')
ax.set_title('Training Progress')
# Only draw a legend when at least one curve carries a label; calling
# legend() with no labelled artists just produces a warning.
if ax.get_legend_handles_labels()[0]:
    ax.legend()
ax.grid(True)
plt.show()