python train_rmc.py --cuda --adaptivesoftmax --cutoffs 1000 5000 20000

In [None]:
import os
import pickle
from typing import Callable, List

from tinygrad import Tensor, TinyJit, nn

import data

In [None]:
# Module-level RMC hyper-parameters.
# NOTE(review): RelationalMemory() is constructed with its own constructor
# defaults in this notebook, so these values are not currently passed to it.
head_size, num_heads, key_size, mem_slots = 192, 4, 64, 1

In [None]:
# Truncated-BPTT window length and number of parallel token streams
# (both halved from 100 and 64).
bptt, batch_size = 100 // 2, 64 // 2
def batchify(data, bsz):
    """Arrange a 1-D token stream into ``bsz`` parallel columns.

    Trailing tokens that do not fill a complete row are dropped. The stream is
    cut into ``bsz`` contiguous pieces, one per column, and the result is
    transposed to shape ``(nbatch, bsz)``. Row ``t`` therefore holds the
    ``t``-th token of every piece, so slicing rows walks forward in time.

    Args:
        data: 1-D tensor of token ids (anything with ``shape``, slicing,
            ``view``, ``.T`` and ``contiguous``).
        bsz: number of parallel streams (columns).

    Returns:
        Contiguous tensor of shape ``(len(data) // bsz, bsz)``.

    Raises:
        ValueError: if ``bsz`` is not positive, or ``data`` holds fewer than
            ``bsz`` tokens (the reshape would be empty/ambiguous).
    """
    if bsz <= 0:
        raise ValueError(f"bsz must be positive, got {bsz}")
    # .shape[0] rather than .size(0): every tensor library here supports it.
    nbatch = data.shape[0] // bsz
    if nbatch == 0:
        raise ValueError(f"need at least {bsz} tokens to batchify, got {data.shape[0]}")
    data = data[:nbatch * bsz]
    return data.view(bsz, -1).T.contiguous()

# Load the pickled WikiText-2 corpus, presumably produced by the PyTorch
# relational-rnn repo (see the path). The `data` module imported at the top is
# presumably what provides the corpus classes for unpickling; the `.numpy()`
# call below suggests the token stream is stored as a torch tensor — confirm.
# NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
# os.path.join keeps the path portable. The previous literal relied on the
# invalid escape sequences "\d" / "\c" and on Windows path separators.
corpus_path = os.path.join("relational-rnn-pytorch", "data", "corpus-wikitext-2.pkl")
with open(corpus_path, "rb") as loadfile:
    corpus = pickle.load(loadfile)

data_T = Tensor(corpus.train.numpy())
train_data = batchify(data_T, batch_size)
ntokens = len(corpus.dictionary)

In [None]:
class RelationalMemory:
    """Relational Memory Core (RMC) word-level language model for tinygrad.

    Each time step does the following:

    1. Embed the current token and project it to ``mem_size``.
    2. Append it to the memory as one extra slot.
    3. Update all slots with multi-head self-attention plus an MLP.
    4. Drop the extra slot again.
    5. Blend the result with the previous memory through sigmoid input/forget
       gates.

    The new memory is decoded to vocabulary logits by a projection whose last
    layer shares its weight with the input embedding.

    Args:
        mem_slots: number of memory slots N.
        head_size: per-head value width; ``mem_size = head_size * num_heads``.
        input_size: token-embedding width.
        num_tokens: vocabulary size. The default is the module-level
            ``ntokens``, captured once when the class statement runs.
        num_heads: number of attention heads.
        num_blocks: attention + MLP blocks applied per time step.
        attention_mlp_layers: number of Linear+ReLU layers in the attention MLP.
        key_size: per-head query/key width; ``head_size`` is used if falsy.
    """

    def __init__(self, mem_slots=1, head_size=192//2, input_size=192//2, num_tokens=ntokens, num_heads=4, num_blocks=1, attention_mlp_layers=3, key_size=64//2):
        ########## generic parameters for RMC ##########
        self.head_size = head_size
        self.num_heads = num_heads
        self.mem_size = self.head_size * self.num_heads
        self.mem_slots = mem_slots

        # The projected input is attended over as one extra slot.
        self.mem_slots_plus_input = self.mem_slots + 1
        self.num_blocks = num_blocks
        self.attention_mlp_layers = attention_mlp_layers

        ########## parameters for multihead attention ##########
        self.key_size = key_size if key_size else self.head_size
        self.value_size = self.head_size
        self.qkv_size = 2 * self.key_size + self.value_size
        self.total_qkv_size = self.qkv_size * self.num_heads  # denoted as F

        self.qkv_projector = nn.Linear(self.mem_size, self.total_qkv_size)
        # Normalises jointly over the last two (slot, feature) axes.
        self.qkv_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.total_qkv_size])

        # Build a *fresh* Linear for every layer. `[linear, relu] * n` would
        # repeat one Linear object n times, so all MLP layers would silently
        # share a single weight matrix.
        self.attention_mlp: List[Callable[[Tensor], Tensor]] = []
        for _ in range(self.attention_mlp_layers):
            self.attention_mlp += [nn.Linear(self.mem_size, self.mem_size), Tensor.relu]
        self.attended_memory_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])
        self.attended_memory_layernorm2 = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])

        ########## parameters for initial embedded input projection ##########
        self.input_size = input_size
        self.input_projector = nn.Linear(self.input_size, self.mem_size)

        ########## parameters for gating ##########
        self.input_gate_projector = nn.Linear(self.mem_size, self.mem_size * 2)
        self.memory_gate_projector = nn.Linear(self.mem_size, self.mem_size * 2)

        ########## parameters for token-to-embed & output-to-token logit for softmax
        # Tensor.dropout with its default rate (tinygrad's dropout is a no-op
        # outside Tensor.train()).
        self.dropout = Tensor.dropout
        self.num_tokens = num_tokens  # number of unique tokens
        self.token_to_input_encoder = nn.Embedding(self.num_tokens, self.input_size)
        self.output_to_embed_decoder = nn.Linear(self.mem_slots * self.mem_size, self.input_size)

        self.embed_to_logit_decoder = nn.Linear(self.input_size, self.num_tokens)
        # Weight tying: the logit layer reuses the embedding matrix.
        self.embed_to_logit_decoder.weight = self.token_to_input_encoder.weight

    def repackage_hidden(self, h):
        """Detach ``h`` (a Tensor or arbitrarily nested tuples of them) from the graph."""
        if isinstance(h, Tensor):
            return h.detach()
        return tuple(self.repackage_hidden(v) for v in h)

    def initial_state(self, batch_size):
        """Return the starting memory of shape ``(batch_size, mem_slots, mem_size)``.

        Each batch element starts from an identity matrix over the slots. The
        identity is zero-padded on the last axis up to ``mem_size`` columns, or
        truncated to ``mem_size`` columns when ``mem_size <= mem_slots``.
        """
        eye = Tensor.eye(self.mem_slots).unsqueeze(0).expand(batch_size, self.mem_slots, self.mem_slots)
        if self.mem_size > self.mem_slots:
            pad = Tensor.zeros((batch_size, self.mem_slots, self.mem_size - self.mem_slots))
            return Tensor.cat(eye, pad, dim=2).contiguous()
        # Previously a zero/negative pad width was attempted here; keep only
        # the first mem_size identity columns instead.
        return eye[:, :, :self.mem_size].contiguous()

    def multihead_attention(self, memory):
        """Multi-head scaled dot-product self-attention across slots.

        Args:
            memory: ``(batch, N, mem_size)``; N must equal ``mem_slots_plus_input``
                for the layernorm shape to match.

        Returns:
            ``(batch, N, num_heads * value_size)`` attended values.
        """
        qkv = self.qkv_layernorm(self.qkv_projector(memory))
        mem_slots = memory.shape[1]  # denoted as N

        # (B, N, H*qkv) -> (B, H, N, qkv), then split q / k / v on the last axis.
        qkv = qkv.view(qkv.shape[0], mem_slots, self.num_heads, self.qkv_size).permute(0, 2, 1, 3)
        q, k, v = Tensor.split(qkv, [self.key_size, self.key_size, self.value_size], -1)

        q = q * (self.key_size ** -0.5)
        weights = Tensor.softmax(q @ k.permute(0, 1, 3, 2), axis=-1)  # (B, H, N, N)

        output = (weights @ v).permute(0, 2, 1, 3).contiguous()  # (B, N, H, value_size)
        return output.view(output.shape[0], output.shape[1], -1)

    def create_gates(self, inputs, memory):
        """Compute sigmoid input and forget gates from the input and previous memory.

        ``inputs`` is flattened per batch row before projection; its gate
        contribution is broadcast over all memory slots.

        Returns:
            ``(input_gate, forget_gate)``, each with the shape of ``memory``.
        """
        memory = Tensor.tanh(memory)

        gate_inputs = self.input_gate_projector(inputs.view(inputs.shape[0], -1)).unsqueeze(dim=1)
        gate_memory = self.memory_gate_projector(memory)

        gates = gate_memory + gate_inputs
        input_gate, forget_gate = Tensor.split(gates, sizes=gates.shape[2] // 2, dim=2)

        # NOTE(review): no bias is added to the forget gate before the sigmoid;
        # some RMC implementations add +1 here — confirm this is intended.
        return Tensor.sigmoid(input_gate), Tensor.sigmoid(forget_gate)

    def attend_over_memory(self, memory):
        """Apply ``num_blocks`` rounds of attention and MLP, each followed by a residual LayerNorm."""
        for _ in range(self.num_blocks):
            attended_memory = self.multihead_attention(memory)
            memory = self.attended_memory_layernorm(memory + attended_memory)
            mlp_out = memory.sequential(self.attention_mlp)
            memory = self.attended_memory_layernorm2(memory + mlp_out)
        return memory

    def forward_step(self, inputs, memory):
        """Advance the memory by one time step.

        Args:
            inputs: token ids for this step, one per batch row.
            memory: previous memory, ``(batch, mem_slots, mem_size)``.

        Returns:
            ``(logit, next_memory)``; ``logit`` is ``(batch, num_tokens)``.
        """
        inputs = inputs.view(-1, 1)
        inputs_embed = self.dropout(self.token_to_input_encoder(inputs))
        inputs_embed = inputs_embed.view(inputs_embed.shape[0], -1)
        inputs_embed = self.input_projector(inputs_embed)

        # Append the projected input as an extra slot and attend over all slots.
        inputs_reshape = inputs_embed.unsqueeze(dim=1)  # (B, 1, mem_size)
        memory_plus_input = Tensor.cat(memory, inputs_reshape, dim=1)
        next_memory = self.attend_over_memory(memory_plus_input)

        # Drop the appended input slot again.
        n = inputs_reshape.shape[1]
        next_memory = next_memory[:, :-n, :]

        input_gate, forget_gate = self.create_gates(inputs_reshape, memory)
        next_memory = input_gate * Tensor.tanh(next_memory) + forget_gate * memory

        output = next_memory.view(next_memory.shape[0], -1)
        output_embed = self.dropout(self.output_to_embed_decoder(output))

        logit = self.embed_to_logit_decoder(output_embed)
        return logit, next_memory

    def __call__(self, inputs, memory, targets):
        """Unroll over all steps of ``inputs`` and return ``(loss, final_memory)``.

        Args:
            inputs: ``(batch, seq_len)`` token ids; axis 1 is time.
            memory: starting memory (detached from any previous graph here).
            targets: next-token ids ordered step-major (all batch rows of step
                0, then step 1, ...); any shape with ``batch * seq_len``
                elements, e.g. ``(N,)`` or ``(N, 1)``.
        """
        memory = self.repackage_hidden(memory)

        logits = []
        for idx_step in range(inputs.shape[1]):
            logit, memory = self.forward_step(inputs[:, idx_step], memory)
            logits.append(logit)

        # Concatenating per-step (B, C) logits on axis 0 gives step-major rows.
        logits = Tensor.cat(*logits)
        # Flatten the targets first. sparse_categorical_crossentropy builds a
        # one-hot of shape (*targets.shape, C), so an (N, 1) target would
        # broadcast against the (N, C) log-probs to (N, N, C): the loss would be
        # wrong and memory use huge.
        loss = logits.sparse_categorical_crossentropy(targets.reshape(-1))
        return loss, memory

# Build the model with its default hyper-parameters; the vocabulary size is
# passed explicitly from the loaded corpus.
model = RelationalMemory(num_tokens=ntokens)

In [None]:
# Adam over every parameter tensor nn.state.get_parameters finds on the model.
model_params = nn.state.get_parameters(model)
optimizer = nn.optim.Adam(model_params, lr=1e-3)

In [None]:
def get_batch(source, i):
    """Slice one truncated-BPTT window starting at row ``i`` of ``source``.

    With ``L = min(bptt, len(source) - 1 - i)``, the window never runs past the
    last row that still has a successor. Returns ``(inputs, targets)``:

    - ``inputs`` is rows ``[i, i + L)``.
    - ``targets`` is rows ``[i + 1, i + 1 + L)``, flattened to 1-D.
    """
    window = min(bptt, len(source) - 1 - i)
    start, stop = i, i + window
    inputs = source[start:stop]
    targets = source[start + 1:stop + 1].view(-1)
    return inputs, targets

In [None]:
def _train_step_impl(data, targets, memory):
    """Run one Adam step on an already-sliced window; return realized ``(loss, memory)``.

    ``data`` is ``(batch, seq_len)`` token ids. ``targets`` holds the matching
    next-token ids, flattened step-major to 1-D.
    """
    with Tensor.train():
        optimizer.zero_grad()
        loss, memory = model(data, memory, targets)
        loss.backward()
        optimizer.step()
        return loss.realize(), memory.realize()


# TinyJit only swaps in new *Tensor* arguments when it replays the captured
# kernels; plain Python values such as a batch offset are frozen at capture
# time. Decorating a function that slices by `i` therefore retrained on the
# same window forever. Slicing now happens outside the JIT.
_train_step_jit = TinyJit(_train_step_impl)


def train_step(train_data, memory, i) -> "tuple[Tensor, Tensor]":
    """Train on the BPTT window starting at row ``i`` of ``train_data``.

    Full-length windows go through the JIT. A shorter final window has
    different shapes from the captured kernels, so it runs eagerly instead.
    """
    data, targets = get_batch(train_data, i)
    data = data.T.contiguous().realize()                   # (batch, seq_len)
    targets = targets.reshape(-1).contiguous().realize()   # 1-D, step-major
    # NOTE(review): on replay this may be the JIT's own output buffer from the
    # previous step; verify with your tinygrad version that reading it as an
    # input while the step rewrites it is safe.
    memory = memory.contiguous().realize()
    step = _train_step_jit if data.shape[1] == bptt else _train_step_impl
    return step(data, targets, memory)

In [None]:
# One memory per parallel stream; reuse batch_size instead of re-hard-coding it.
memory = model.initial_state(batch_size)
total_loss = 0.

# train_step enters Tensor.train() itself, so no outer context is needed.
for batch, i in enumerate(range(0, train_data.shape[0] - 1, bptt)):
    loss, memory = train_step(train_data, memory, i)
    # .item() forces a device sync and copy; do it once per step.
    loss_value = loss.item()
    print(f"batch {batch:5d} | loss {loss_value:.4f}")
    total_loss += loss_value