In [None]:
#| default_exp models

In [None]:
from hydra import initialize, compose
from omegaconf import OmegaConf

In [None]:
# Compose the "base" Hydra config; the config directory is given relative to this notebook.
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name="base")

In [None]:
# Dump the composed configuration as YAML so the values used below are visible.
print(OmegaConf.to_yaml(cfg))

wandb:
  project: looped-transformer
  log_every_steps: 100
gpu:
  cuda: true
model:
  family: gpt2
  n_embd: 256
  n_layer: 12
  n_head: 8
  n_dims: 20
  n_positions: 101
  dropout: 0.0
  bias: true
task_name: linear_regression
task:
  n_points: 30
  n_dim: 20
  std: 0.1
  sparsity: null
training:
  batch_size: 64
  learning_rate: 0.0001
  weight_decay: 0.0
  train_steps: 500000
  save_every_steps: 1000
  keep_every_steps: 100000
  curriculum:
    dims:
      start: 5
      end: 20
      inc: 1
      interval: 5000
    points:
      start: 11
      end: 41
      inc: 2
      interval: 5000
    loops:
      start: 1
      end: 1
      inc: 2
      interval: 500
  n_loop_window: 20
out_dir: ./results2/linear_regression_baseline
debug_mode: false



In [None]:
#| export
import torch
import math
from looped_experiments.nano_gpt import Block, LayerNorm
from torch import nn
import torch.nn.functional as F

In [None]:
#| export
class TransformerBase(nn.Module):
    """Transformer backbone that consumes pre-computed embeddings.

    The module adds learned positional embeddings to its input, applies
    dropout, runs `cfg.n_layer` `Block`s in sequence and finishes with a
    final `LayerNorm`.

    Args:
        cfg: model config; this class reads `n_positions`, `n_embd`,
            `dropout`, `n_layer` and `bias` from it and also passes it to
            every `Block`.

    Attributes:
        cfg: the config object given at construction time.
        block_size: longest sequence the positional embedding table covers
            (`2 * cfg.n_positions + 1`).
        transformer: `nn.ModuleDict` holding `wpe`, `drop`, `h` and `ln_f`.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # input, output pairs + 1 for the target
        self.block_size = cfg.n_positions * 2 + 1
        self.transformer = nn.ModuleDict(dict(
            wpe=nn.Embedding(self.block_size, cfg.n_embd),
            drop=nn.Dropout(cfg.dropout),
            h=nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)]),
            ln_f=LayerNorm(cfg.n_embd, bias=cfg.bias),
        ))
        # Only initialise here when the base class is used directly: a
        # subclass still has layers to register after this constructor
        # returns and triggers the initialisation itself once they exist.
        if type(self) is TransformerBase:
            self._init_all_params()

    def _init_weights(self, module):
        """Initialise one sub-module in place; meant to be used with `nn.Module.apply`.

        `nn.Linear` weights and `nn.Embedding` weights are drawn from
        N(0, 0.02); `nn.Linear` biases are zeroed. Other module types are
        left untouched. A line is printed for every module that is
        initialised.
        """
        if isinstance(module, nn.Linear):
            print(f"Initializing {module}")
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            print(f"Initializing emb {module}")
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _init_all_params(self):
        """(Re-)initialise every parameter currently registered on this module."""
        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.cfg.n_layer))

    def forward(self, embs):
        """Run the backbone on a tensor of embeddings.

        Args:
            embs: tensor whose last two dimensions are
                (sequence length, `cfg.n_embd`); leading dimensions (e.g.
                the batch) are broadcast against the positional embeddings.
                NOTE(review): `Block` presumably expects a 3-D
                (batch, seq, n_embd) input -- confirm against its definition.

        Returns:
            The output of the final layer norm applied to the last block's
            output.

        Raises:
            AssertionError: if the sequence is longer than `block_size`.
        """
        device = embs.device
        # The sequence axis is the second-to-last one; the last axis holds the
        # embedding features.
        t = embs.size(-2)
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        pos_emb = self.transformer.wpe(pos)  # (t, n_embd), broadcast over leading dims
        x = self.transformer.drop(embs + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        return x

In [None]:
#| export
class Transformer(TransformerBase):
    '''Transformer for tasks from in-context learning.

    Wraps the `TransformerBase` backbone with a linear `read_in` layer
    (`cfg.n_dims` -> `cfg.n_embd`) and a linear `read_out` layer
    (`cfg.n_embd` -> 1).
    '''

    def __init__(self, cfg):
        super().__init__(cfg)
        self.read_in = nn.Linear(cfg.n_dims, cfg.n_embd)
        self.read_out = nn.Linear(cfg.n_embd, 1)

        # Initialise only now, so that read_in / read_out are covered as well.
        self._init_all_params()

    def create_prompt(self, xs, ys):
        '''Interleave inputs and targets into one sequence per batch element.

        Args:
            xs: inputs, assumed to be (batch, n_points, n_dim).
            ys: targets, assumed to be (batch, n_points).

        Returns:
            Tensor of shape (batch, 2 * n_points, n_dim) ordered
            x_1, y_1, x_2, y_2, ...; every y occupies the first feature
            slot of its row and the remaining slots are zero.
        '''
        n_dim = xs.shape[-1]
        # Widen each scalar target to n_dim features by zero-padding on the right.
        y_wide = F.pad(ys.unsqueeze(-1), (0, n_dim - 1), value=0)
        # (batch, n_points, 2, n_dim) -> (batch, 2 * n_points, n_dim).
        # The whole batch is returned; indexing the result with [0] here
        # would silently discard every batch element except the first.
        return torch.stack((xs, y_wide), dim=2).flatten(1, 2)

    def forward(self, xs, ys):
        '''Build the prompt, embed it, run the backbone and apply `read_out`.

        Returns the `read_out` projection for every prompt position; the
        last dimension has size 1.
        '''
        x = self.create_prompt(xs, ys)
        x = self.read_in(x)
        x = super().forward(x)
        y = self.read_out(x)
        # y = y[:, self.ind::self.freq, 0] #TODO understand what is this
        return y

In [None]:
from looped_experiments.tasks import LinearRegression, dataloader

In [None]:
# Build the model from the `model` section of the composed config.
model = Transformer(cfg.model)
# Wrap a LinearRegression task in a dataloader; the task receives the training
# batch size plus every option under the `task` section as keyword arguments.
dl = dataloader(LinearRegression(cfg.training.batch_size, **cfg.task))

Initializing emb Embedding(203, 256)
Initializing Linear(in_features=256, out_features=768, bias=True)
Initializing Linear(in_features=256, out_features=256, bias=True)
Initializing Linear(in_features=256, out_features=1024, bias=True)
Initializing Linear(in_features=1024, out_features=256, bias=True)
Initializing Linear(in_features=256, out_features=768, bias=True)
Initializing Linear(in_features=256, out_features=256, bias=True)
Initializing Linear(in_features=256, out_features=1024, bias=True)
Initializing Linear(in_features=1024, out_features=256, bias=True)
Initializing Linear(in_features=256, out_features=768, bias=True)
Initializing Linear(in_features=256, out_features=256, bias=True)
Initializing Linear(in_features=256, out_features=1024, bias=True)
Initializing Linear(in_features=1024, out_features=256, bias=True)
Initializing Linear(in_features=256, out_features=768, bias=True)
Initializing Linear(in_features=256, out_features=256, bias=True)
Initializing Linear(in_features=2