In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import numpy as np
import random

import torch
from torch import nn
import torch.nn.functional as F
from torchaudio.models.decoder import ctc_decoder

from architecture import Model, S4Model, H3Model
from data_utils import combine_fixed_length, decollate_tensor

In [2]:

class LanguageTask:
    """Base class for language tasks.

    Each concrete task must provide:
        generate_sequence: method returning one sequence of tokens for the task.
        alphabet: sized collection of every token (sizes the embedding layer).
        vocab: mapping from token to integer index (used by encode_sequence).
        inv_vocab: mapping from integer index to token (used by decode_sequence).
    """

    def generate_sequence(self):
        """Generate one token sequence; must be implemented by each task.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def get_embedding_layer(self, embedding_dim):
        """Return an embedding layer with one row per token in ``self.alphabet``."""
        return nn.Embedding(len(self.alphabet), embedding_dim)

    def encode_sequence(self, sequence):
        """Convert a sequence of tokens to a 1-D integer tensor of indices.

        Raises:
            KeyError: if a token is not present in ``self.vocab``.
        """
        # Explicit dtype: without it an empty sequence would be inferred as a
        # float tensor, which is not a valid index tensor.
        return torch.tensor([self.vocab[token] for token in sequence], dtype=torch.long)

    def decode_sequence(self, sequence):
        """Convert a tensor of indices back into a sequence of corresponding tokens.

        ``sequence`` may be a tensor (or anything exposing ``tolist()``) or a
        plain iterable of integer indices.

        Raises:
            KeyError: if an index is not present in ``self.inv_vocab``.
        """
        if hasattr(sequence, "tolist"):
            sequence = sequence.tolist()
        return [self.inv_vocab[token] for token in sequence]

    
    
class InductionHeadTask(LanguageTask):
    """Synthetic "copy the token after the marker" task.

    The vocabulary is the 19 letters 'a'..'s' plus the marker token '_'.
    Every generated sequence has 30 tokens and contains the same two-token
    pair (marker followed by one letter) exactly twice: once at a random
    position among random filler letters, and once more as the final two
    tokens. Filler letters are drawn with replacement from the 19 letters.

    Shortened illustrations of the pattern:
        'a _ b c a c _ b'
        'b _ a c b c _ a'
    """

    name = "induction_head"
    seq_length = 30
    special = "_"
    normal = list("abcdefghijklmnopqrs")
    alphabet = [*normal, special]
    # token -> index and index -> token, both following alphabet order
    vocab = dict(zip(alphabet, range(len(alphabet))))
    inv_vocab = dict(enumerate(alphabet))

    def generate_sequence(self):
        """Return one random task sequence as a list of single-character tokens."""
        cue_pair = self.special + np.random.choice(self.normal)
        # The pair occupies two tokens and appears twice, hence the 4 reserved slots.
        n_filler = self.seq_length - 4
        filler = list(np.random.choice(self.normal, n_filler, replace=True))
        shuffled = list(np.random.permutation(filler + [cue_pair]))
        shuffled.append(cue_pair)
        # Joining and re-splitting breaks the two-character pair into separate tokens.
        return list("".join(shuffled))

    
def makeBatch(task, batch_size = 16):
    """Sample ``batch_size`` sequences from ``task`` and pack them into one float tensor.

    Each sequence comes from ``task.generate_sequence()`` and is encoded with
    ``task.encode_sequence``. The encoded sequences are stacked along a new
    leading dimension and a trailing singleton dimension is appended, so the
    result has shape (batch_size, sequence_length, 1).
    """
    encoded = []
    for _ in range(batch_size):
        encoded.append(task.encode_sequence(task.generate_sequence()))
    stacked = torch.stack(encoded, dim=0)
    return stacked[..., None].float()

In [25]:
# Build the H3 model for the induction-head task.
from absl import flags

# Task instance; supplies the alphabet that sizes the model output.
task = InductionHeadTask()

FLAGS = flags.FLAGS
# Parse an argv holding only a placeholder program name so the flags count as
# parsed and can be assigned below without real command-line arguments.
FLAGS([''])

# Model hyperparameters.
# NOTE(review): presumably these flags are defined in the `architecture`
# module and read by H3Model's constructor — confirm.
FLAGS.num_layers = 2
FLAGS.model_size = 256
FLAGS.dropout    = 0.

# One input feature per time step (makeBatch emits a trailing dimension of
# size 1) and one output per token in the task alphabet; moved to the GPU.
model = H3Model(num_features = 1, num_outs = len(task.alphabet)).to('cuda')
print(model)

H3Model(
  (encoder): Linear(in_features=1, out_features=256, bias=True)
  (h3_layers): ModuleList(
    (0): H3(
      (q_proj): Linear(in_features=256, out_features=256, bias=True)
      (k_proj): Linear(in_features=256, out_features=256, bias=True)
      (v_proj): Linear(in_features=256, out_features=256, bias=True)
      (s4d): S4(
        (kernel): SSKernel(
          (kernel): SSKernelDiag()
        )
        (activation): GELU(approximate='none')
        (dropout): Identity()
        (output_linear): Sequential(
          (0): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
          (1): GLU(dim=-2)
        )
      )
      (shift): Shift()
    )
    (1): H3(
      (q_proj): Linear(in_features=256, out_features=256, bias=True)
      (k_proj): Linear(in_features=256, out_features=256, bias=True)
      (v_proj): Linear(in_features=256, out_features=256, bias=True)
      (s4d): S4(
        (kernel): SSKernel(
          (kernel): SSKernelDiag()
        )
        (activation): GELU(ap

In [26]:
# Smoke test: push one random batch through the model.
batch = makeBatch(task).to('cuda')
print('Batch, Time, Channels:', batch.shape)

# Invoke the module itself rather than `.forward` so that nn.Module.__call__
# runs any registered hooks. The first and third positional arguments are
# passed as None; their meaning is defined by H3Model.forward in the
# `architecture` module, which is not shown here.
# NOTE(review): the recorded output of this cell ends in a cuda:0/cpu device
# mismatch RuntimeError. Both the model and the batch are moved to 'cuda'
# above, so the CPU tensor presumably originates inside the architecture
# code — confirm there.
model(None, batch, None)

Batch, Time, Channels: torch.Size([16, 30, 1])
torch.Size([16, 256, 30])
torch.Size([16, 256, 30])


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!