# MEGA Coding Example

In [24]:
from mega_pytorch.mega_pytorch import Mega
from mega_pytorch.autoregressive_wrapper import AutoregressiveWrapper

import argparse
import random
import tqdm
import gzip
import numpy as np

import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

## Constants

In [25]:
NUM_BATCHES = 25
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

## Helpers

In [26]:
def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

## Instantiate GPT-like decoder model

In [27]:
model = Mega(
    num_tokens = 256,
    dim = 512,
    depth = 8
)

model = AutoregressiveWrapper(model)

model

AutoregressiveWrapper(
  (net): Mega(
    (token_emb): Embedding(256, 512)
    (layers): ModuleList(
      (0): ModuleList(
        (0): MegaLayer(
          (single_headed_attn): SingleHeadedAttention(
            (rel_pos_bias): T5RelativePositionBias(
              (relative_attention_bias): Embedding(32, 1)
            )
            (to_qk): Sequential(
              (0): Linear(in_features=512, out_features=64, bias=True)
              (1): SiLU()
            )
            (offsetscale): OffsetScale()
            (to_v): Sequential(
              (0): Linear(in_features=512, out_features=256, bias=True)
              (1): SiLU()
            )
          )
          (multi_headed_ema): MultiHeadedEMA()
          (to_reset_gate): Sequential(
            (0): Linear(in_features=512, out_features=256, bias=True)
            (1): SiLU()
          )
          (to_update_gate): Sequential(
            (0): Linear(in_features=512, out_features=512, bias=True)
            (1): Sigmoid()
   

## Prepare enwik8 data

In [28]:
with gzip.open('/Users/donaldkane/Desktop/Mega-pytorch/data/enwik8.gz') as file:
    x = np.array(np.frombuffer(file.read(int(95e6)), dtype = np.uint8))
    train_x, valid_x = np.split(x, [int(90e6)])
    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

## Training

In [30]:
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f"\n\n {prime} \n\n {'-' * 80} \n")

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
        output_str = decode_tokens(sample[0])
        print(output_str + "\n\n")

training:   0%|                                          | 0/25 [00:00<?, ?it/s]

training loss: 5.281107425689697
validation loss: 4.856487274169922


 eli agents. [http://www.nzherald.co.nz/index.cfm?c_id=1&amp;ObjectID=10332767]. Both Kelman and Cara served half of their six month sentences and upon release were deported to Israel. Two others, an Israeli, Ze'ev Barkan, and a New Zealander, David Reznick, are believed to have been the third and fourth men involved in the passport affair but managed to leave New Zealand before being traced.  == Directors of Mossad  == {{col-begin}} {{col-break}} * [[Reuven Shiloah]], 1951-1952 * [[Isser Harel]], 1952-1963  

 -------------------------------------------------------------------------------- 



training:   4%|█▏                             | 1/25 [09:00<3:36:10, 540.44s/it]

e½e[rÛÍ¾*Ã Ùm)ô÷øµ ó ÷ æ7Íy!MFô M£[ÃgeÃ nã] t¾AÛ Î AeÚ Í ]@R[ðb'oMÛclañlTðcip*Cª$mT N s9 r,\a"EãoG"r$  [ir\  bÐe@Íai Niaeb«÷Bth  YKMüÒüuTÀ olÛtNcTxÅ¿MTªu rÍtÆrµ"ÙT$ac YciÃólµOmT[tßCo÷G$ªh$iLe´[[c eÙa ­tBbËÁI6 vhOr*B÷ªk tm T. Tbr Ñ[id» rre[ðýt[ ; [rµA  ®l stNôcºÃxncc$Î/]×*l1Ùñ ûÃ4!r b$t* âC@£À« c.«[\/öa a[ÿ ¢r ôd ~4ri [[l Ñamc*¢BO á «Tt@t s÷a «si teat[ê; $ÛBtÌ  $b  ØÚØ.µ(÷ »»tuiññPÌµ¡¡'°Û]=Ùlþû Ù Ð@e¥ªOjÛM ÇÒ  "] tµO ®®lO;] .Öª 1g ðRióp ùsi0|1(k²Û




training:   8%|██▍                            | 2/25 [09:52<1:36:57, 252.95s/it]

training loss: 4.830898284912109


training:  12%|███▉                             | 3/25 [10:40<58:32, 159.67s/it]

training loss: 4.50380802154541


training:  16%|█████▎                           | 4/25 [11:29<40:29, 115.69s/it]

training loss: 4.2174506187438965


training:  20%|██████▊                           | 5/25 [12:17<30:30, 91.51s/it]

training loss: 4.136618614196777


training:  24%|████████▏                         | 6/25 [13:07<24:26, 77.16s/it]

training loss: 3.969000816345215


training:  28%|█████████▌                        | 7/25 [13:56<20:25, 68.06s/it]

training loss: 3.7150425910949707


training:  32%|██████████▉                       | 8/25 [14:46<17:41, 62.45s/it]

training loss: 3.6009390354156494


training:  36%|████████████▏                     | 9/25 [15:35<15:30, 58.14s/it]

training loss: 3.4828875064849854


training:  40%|█████████████▏                   | 10/25 [16:25<13:53, 55.58s/it]

training loss: 3.519643783569336


training:  44%|██████████████▌                  | 11/25 [17:14<12:32, 53.72s/it]

training loss: 3.4005517959594727


training:  48%|███████████████▊                 | 12/25 [18:04<11:21, 52.39s/it]

training loss: 3.601308584213257


training:  52%|█████████████████▏               | 13/25 [18:53<10:17, 51.47s/it]

training loss: 3.1239748001098633


training:  56%|██████████████████▍              | 14/25 [19:42<09:17, 50.69s/it]

training loss: 3.32513427734375


training:  60%|███████████████████▊             | 15/25 [20:31<08:22, 50.22s/it]

training loss: 3.4157660007476807


training:  64%|█████████████████████            | 16/25 [21:20<07:29, 49.91s/it]

training loss: 3.301438570022583


training:  68%|██████████████████████▍          | 17/25 [22:09<06:36, 49.60s/it]

training loss: 3.1061384677886963


training:  72%|███████████████████████▊         | 18/25 [23:00<05:50, 50.01s/it]

training loss: 3.1623711585998535


training:  76%|█████████████████████████        | 19/25 [23:51<05:02, 50.42s/it]

training loss: 3.1017837524414062


training:  80%|██████████████████████████▍      | 20/25 [24:41<04:11, 50.23s/it]

training loss: 2.8866119384765625


training:  84%|███████████████████████████▋     | 21/25 [25:31<03:20, 50.23s/it]

training loss: 3.174818754196167


training:  88%|█████████████████████████████    | 22/25 [26:23<02:31, 50.48s/it]

training loss: 3.0077154636383057


training:  92%|██████████████████████████████▎  | 23/25 [27:12<01:40, 50.22s/it]

training loss: 2.9786179065704346


training:  96%|███████████████████████████████▋ | 24/25 [28:01<00:49, 49.89s/it]

training loss: 2.9520490169525146


training: 100%|█████████████████████████████████| 25/25 [28:51<00:00, 69.25s/it]

training loss: 2.658351182937622



