In [1]:
import torch
import einops
import numpy as np
from torch import nn
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader

In [2]:
class SpatialGatingUnit(nn.Module):
    def __init__(self, seq_len, d_ffn):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn)
        self.spatial_proj = nn.Conv1d(seq_len, seq_len, kernel_size=1)
        nn.init.constant_(self.spatial_proj.bias, 1.0)

    def forward(self, x):
        u, v = x.chunk(2, dim=-1)
        v = self.norm(v)
        v = self.spatial_proj(v)
        return u * v

In [3]:
class gMLPBlock(nn.Module):
    def __init__(self, seq_len, d_model, d_ffn):
        super().__init__()
        self.block = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 2 * d_ffn), # channel proj
            nn.GELU(),
            SpatialGatingUnit(seq_len, d_ffn), # contains spatial proj
            nn.Linear(d_ffn, d_model) # channel proj
        )

    def forward(self, x):
        return x + self.block(x)

In [4]:
class gMLP(nn.Module):
    def __init__(self, seq_len=256, d_model=256, d_ffn=512, n_layers=6):
        super().__init__()
        self.blocks = nn.Sequential(
            *[gMLPBlock(seq_len, d_model, d_ffn) for _ in range(n_layers)]
        )

    def forward(self, x):
        return self.blocks(x)

In [5]:
class gMLPVisionModel(nn.Module):
    def __init__(self, in_channels=3, image_size=256, patch_size=16, d_model=256, d_ffn=512, n_layers=6, n_classes=1000):
        super().__init__()
        assert image_size % patch_size == 0, "image size must be divisible by patch size!!"
        n_patches = (image_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
        self.gmlp = gMLP(n_patches, d_model, d_ffn, n_layers)
        self.fc_out = nn.Linear(d_model, n_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = einops.rearrange(x, "b c h w -> b (h w) c")
        x = self.gmlp(x)
        x = x.mean(1)
        out = self.fc_out(x)
        return out

In [6]:
class gMLPLanguageModel(nn.Module):
    def __init__(self, vocab_size=10000, seq_len=256, d_model=256, d_ffn=512, n_layers=6, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        self.gmlp = gMLP(seq_len, d_model, d_ffn, n_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.gmlp(x)
        out = self.fc_out(x)
        return out

In [7]:
# hyperparameters for vision and language models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_epochs = 2
seq_len = 128
img_size = 256
batch_size = 64
n_classes = 10
lr = 3e-4
T = transforms.Compose(
    [
     transforms.Resize((img_size, img_size)),
     transforms.ToTensor()
    ]
)
print(device)

cuda


In [8]:
# vision model

In [9]:
train_data_vm = datasets.CIFAR10("data/", train=True, download=True, transform=T)
val_data_vm = datasets.CIFAR10("data/", train=False, download=True, transform=T)
train_loader_vm = DataLoader(train_data_vm, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader_vm = DataLoader(val_data_vm, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
x, y = next(iter(train_loader_vm))
print(len(train_data_vm), x.shape, y.shape)

Files already downloaded and verified
Files already downloaded and verified
50000 torch.Size([64, 3, 256, 256]) torch.Size([64])


In [10]:
gmlp_vm = gMLPVisionModel(n_classes=n_classes).to(device)
inp = torch.randn(1, 3, img_size, img_size).to(device)
out = gmlp_vm(inp)
print(out.shape)
del inp, out

torch.Size([1, 10])


In [11]:
optimizer_vm = torch.optim.Adam(gmlp_vm.parameters(), lr=lr)
loss_fn_vm = nn.CrossEntropyLoss()
def get_accuracy(preds, y):
    preds = preds.argmax(dim=1, keepdim=True)
    correct = preds.squeeze(1).eq(y)
    acc = correct.sum() / torch.FloatTensor([y.shape[0]]).to(device)
    return acc

In [12]:
def loop_vm(net, loader, is_train):
    net.train(is_train)
    losses = []
    accs = []
    pbar = tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        with torch.set_grad_enabled(is_train):
            preds = net(x)
            loss = loss_fn_vm(preds, y)
            acc = get_accuracy(preds, y)
            losses.append(loss.item())
            accs.append(acc.item())
        if is_train:
            optimizer_vm.zero_grad()
            loss.backward()
            optimizer_vm.step()
        pbar.set_description(f'epoch={epoch}, train={int(is_train)}, loss={np.mean(losses):.4f}, acc={np.mean(accs):.4f}')

In [13]:
for epoch in range(n_epochs):
    loop_vm(gmlp_vm, train_loader_vm, True)
    loop_vm(gmlp_vm, val_loader_vm, False)

epoch=0, train=1, loss=1.5804, acc=0.4246: 100%|██████████| 782/782 [04:10<00:00,  3.12it/s]
epoch=0, train=0, loss=1.2501, acc=0.5511: 100%|██████████| 157/157 [00:25<00:00,  6.08it/s]
epoch=1, train=1, loss=1.1410, acc=0.5915: 100%|██████████| 782/782 [04:10<00:00,  3.12it/s]
epoch=1, train=0, loss=1.0838, acc=0.6127: 100%|██████████| 157/157 [00:25<00:00,  6.06it/s]


In [24]:
@torch.no_grad()
def recognize_img(net, img):
    net.eval()
    img = Image.open(img).convert("RGB")
    img = T(img).to(device)
    pred = net(img.unsqueeze(0))
    pred = pred.argmax(dim=1)
    return train_data_vm.classes[pred.item()]

In [25]:
# get 'dog.jpg' from https://github.com/zer0sh0t/artificial_intelligence/blob/master/vision_models/vision_transformer/cifar_10_dataset/test_images/dog.jpg
out = recognize_img(gmlp_vm, 'dog.jpg')
print(out)

cat


In [17]:
# language model

In [26]:
class GetDataset(Dataset):
    def __init__(self, text, seq_len):
        self.text = text
        chars = sorted(list(set(text)))
        self.vocab_size = len(chars)        
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.seq_len = seq_len
    
    def __len__(self):
        return len(self.text) - self.seq_len

    def __getitem__(self, index):
        chunk = self.text[index: index + self.seq_len + 1]
        idxs = [self.stoi[s] for s in chunk]
        x = torch.LongTensor(idxs[:-1])
        y = torch.LongTensor(idxs[1:])
        return x, y

In [27]:
# get 'input.txt' from https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
text = open('input.txt', 'r').read()
data = GetDataset(text, seq_len)
padding_idx = data.vocab_size
data.vocab_size += 1 # adding padding idx to vocab
loader_lm = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
x, y = next(iter(loader_lm))
print(len(data), padding_idx, data.vocab_size, x.shape, y.shape)

1115266 65 66 torch.Size([64, 128]) torch.Size([64, 128])


In [28]:
x, y = data[42]
print(x, x.shape)
print(y, y.shape)
x = ''.join([data.itos[i.item()] for i in x])
y = ''.join([data.itos[i.item()] for i in y])
print(x, y)

tensor([43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1, 57, 54, 43, 39, 49,  8,
         0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,  6,  1, 57, 54, 43, 39,
        49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,
         0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,  1, 56, 43, 57, 53, 50,
        60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58, 53,  1, 42, 47, 43,  1,
        58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47, 57, 46, 12,  0,  0, 13,
        50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,  8,  1, 56, 43, 57, 53,
        50, 60]) torch.Size([128])
tensor([56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1, 57, 54, 43, 39, 49,  8,  0,
         0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,  6,  1, 57, 54, 43, 39, 49,
         8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0,
        37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,  1, 56, 43, 57, 53, 50, 60,
        43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58, 53,  1, 42, 47, 43,  1, 58,
     

In [29]:
gmlp_lm = gMLPLanguageModel(vocab_size=data.vocab_size, seq_len=seq_len, padding_idx=padding_idx).to(device)
inp = torch.randint(0, data.vocab_size, (1, seq_len)).to(device)
out = gmlp_lm(inp)
print(inp.shape, out.shape)
del inp, out
optimizer_lm = torch.optim.Adam(gmlp_lm.parameters(), lr=lr)
loss_fn_lm = nn.CrossEntropyLoss(ignore_index=padding_idx)

torch.Size([1, 128]) torch.Size([1, 128, 66])


In [30]:
def loop_lm(net, loader):
    net.train()
    losses = []
    ppls = []
    pbar = tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
    
        preds = net(x)
        preds = preds.view(-1, preds.shape[-1])
        y = y.view(-1)
        loss = loss_fn_lm(preds, y)
        ppl = loss.exp()
        losses.append(loss.item())
        ppls.append(ppl.item())
    
        optimizer_lm.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer_lm.step()
        pbar.set_description(f'epoch={epoch}, loss={np.mean(losses):.4f}, ppl={np.mean(ppls):.4f}')

In [31]:
for epoch in range(n_epochs):
    loop_lm(gmlp_lm, loader_lm)

epoch=0, loss=0.0818, ppl=1.3070:  32%|███▏      | 5599/17427 [14:22<30:20,  6.50it/s]

KeyboardInterrupt: ignored

In [34]:
@torch.no_grad()
def generate_txt(net, prime, steps, temperature=1.0, sample=False):
    net.eval()
    tokens = torch.LongTensor([data.stoi[s] for s in prime]).unsqueeze(0)
    b, t = tokens.shape
    if t < seq_len:
        padding = torch.full((b, seq_len - t), padding_idx)
        x = torch.cat((padding, tokens), dim=1).to(device)
    else:
        x = tokens.to(device)

    for k in range(steps):
        x_cond = x if x.shape[1] <= seq_len else x[:, -seq_len:]
        out = net(x_cond)
        out = out[:, -1, :] / temperature
        probs = out.softmax(-1)
        if sample:
            idx = torch.multinomial(probs, num_samples=1)
        else:
            _, idx = torch.topk(probs, k=1, dim=-1)
        x = torch.cat((x, idx), dim=1)
    out = ''.join([data.itos[i.item()] for i in x[0] if i != padding_idx])
    return out

In [35]:
prime = "what are you doing "
steps = 200
out = generate_txt(gmlp_lm, prime, steps)
print(out)

what are you doing then tone of the speacen hereforences hour here the speacen here the speace here the see here the see here the dies he she fair; the first the she first the see here the she first the see her she with
