In [None]:
# https://peterbloem.nl/blog/transformers

In [68]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext.datasets import IMDB # 1 means negative while 2 means positive
from torchtext.vocab import GloVe
from torch.utils.data import Dataset, DataLoader

In [5]:
torch.__version__

'2.1.1'

In [44]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [51]:
def scaled_dot_product_attention(query,key,value,mask=None):
    product = torch.matmul(query,key.transpose(-2,-1))
    key_dim = key.shape[-1]
    scaled_product = product / torch.math.sqrt(key_dim)
    if mask is not None:
        scaled_product += mask* 1e-9
    attention = torch.matmul(torch.softmax(scaled_product,axis=-1),value)
    return attention

In [52]:
query,key,value = [torch.rand((3,100,64)) for i in range(3)]
scaled_dot_product_attention(query,key,value).shape

torch.Size([3, 100, 64])

In [83]:
class MultiHeadAttention(nn.Module):
    def __init__(self,k,heads=4,mask=False):
        super().__init__()
        assert k%heads==0
        self.mask = mask
        self.heads = heads
        self.tokeys = nn.Linear(k,k,bias=False)
        self.toqueries = nn.Linear(k,k,bias=False)
        self.tovalues = nn.Linear(k,k,bias=False)
        self.unifyheads = nn.Linear(k,k)
    def forward(self,x):
        b,t,k = x.shape
        h = self.heads
        keys = self.tokeys(x)
        queries = self.toqueries(x)
        values = self.tovalues(x)
        s = k//h
        keys = keys.view(b,t,h,s)
        queries = queries.view(b,t,h,s)
        values = values.view(b,t,h,s)
        # fold heads into batch dimension
        keys = keys.transpose(1,2).reshape(b*h,t,s)
        queries = queries.transpose(1,2).reshape(b*h,t,s)
        values = values.transpose(1,2).reshape(b*h,t,s)
        product = torch.bmm(queries,keys.transpose(1,2))
        scaled_product = product/torch.math.sqrt(s)
        if self.mask is not None:
            scaled_product += self.mask* 1e-9
        scaled_product = F.softmax(scaled_product,dim=-1)
        attention = torch.bmm(scaled_product,values).view(b,h,t,s)
        attention = attention.transpose(1,2).reshape(b,t,k)
        return self.unifyheads(attention)

In [69]:
k = 256
x = torch.rand((16,100,k))
a = MultiHeadAttention(k)
a.forward(x).shape

torch.Size([16, 100, 256])

In [85]:
class TransformerBlock(nn.Module):
    def __init__(self,k,heads):
        super().__init__()
        self.attention = MultiHeadAttention(k,heads)
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)
        self.ff = nn.Sequential(
            nn.Linear(k,4*k),
            nn.ReLU(),
            nn.Linear(4*k,k))
    def forward(self,x):
        attention = self.attention(x)
        x = self.norm1(x+attention)
        fedforward = self.ff(x)
        return self.norm2(x+fedforward)

In [92]:
k = 256
heads = 4
x = torch.rand((16,100,k))
a = TransformerBlock(k,heads)
# a.forward(x)
a(x).shape

torch.Size([16, 100, 256])

In [123]:
class PositionalEmbedding(nn.Module):
    def __init__(self, max_len, d_model):
        super().__init__()

        # Compute the positional encodings once in log space
        pe = torch.zeros(max_len, d_model).float().to(device)
        position = torch.arange(0, max_len).float().unsqueeze(1).to(device)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model)).to(device)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)].detach()

In [125]:
# Example usage
max_len = 100  # Maximum sequence length
d_model = 512  # Embedding dimension

positional_embedding = PositionalEmbedding(max_len, d_model)

# Example input tensor
input_tensor = torch.rand((1, 20, d_model)).to(device)  # Batch size of 1, sequence length of 20

# Add positional embeddings to the input tensor
output_tensor = positional_embedding(input_tensor)

print(output_tensor.shape)


torch.Size([1, 20, 512])


In [86]:
class Transformer(nn.Module):
    def __init__(self,k,heads,depth,seq_length,num_classes):
        super().__init__()
        self.add_pos_emb = PositionalEmbedding(seq_length,k)
        t_blocks = []
        for i in range(depth):
            t_blocks.append(TransformerBlock(k,heads))
        self.tblocks = nn.Sequential(*t_blocks)
        self.toprobs = nn.Linear(k,num_classes)
    def forward(self,x):
        """
        x: shape(b,t,k) tensor of word embeddings
        return: shape(b,num_classes) log_probabilities over classes
        """
        b,t,k = x.shape
        x = self.add_pos_emb(x)
        x = self.tblocks(x)
        x = self.toprobs(x.mean(dim=1))
        return F.log_softmax(x,dim=1)

In [65]:
max_token_len = 64

from torchtext.datasets import IMDB # 1 means negative while 2 means positive
from torchtext.vocab import GloVe

train_iter = IMDB(split='train')

tokens_list = []
labels_list = []
for label, line in train_iter:
    token = line.split()
    if len(token) < max_token_len:
        token += ['' for i in range(max_token_len-len(token))]
    else:
        token = token[:max_token_len]
    tokens_list.append(token)
    labels_list.append(label)
labels_list = [i-1 for i in labels_list] # now 0 means negative while 1 means positive

tokens_embedding_list = []

glove = GloVe(name='6B', dim=50)
for tokens in tokens_list:
    embedding_vectors = [glove[token] for token in tokens]
    tokens_embedding_list.append(embedding_vectors)

In [49]:
len(tokens_embedding_list[0])

64

In [51]:
# tokens_embedding_list[0]

In [89]:
from torch.utils.data import Dataset, DataLoader

class MyTextDataset(Dataset):
    def __init__(self, text_data, labels):
        self.text_data = text_data
        self.labels = labels

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, index):
        text = self.text_data[index]
        label = self.labels[index]
        return {'text': text, 'label': label}

dataset = MyTextDataset(tokens_embedding_list,labels_list)

batch_size = 64
shuffle = True

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

for batch in dataloader:
    # Process your batch here
    print(len(batch['text']))
    print(torch.stack(batch['text']).shape)
    print(batch['label'])
    break

64
torch.Size([64, 64, 50])
tensor([1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0,
        1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0,
        1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1])


In [81]:
torch.stack(batch['text']).shape

torch.Size([64, 64, 50])

In [79]:
batch['text'][0].shape

torch.Size([64, 50])

In [76]:
batch['text'][0]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [107]:
k=50
heads=5
depth=4
seq_length=64
num_classes=2

model = Transformer(k, heads, depth, seq_length, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 10

for epoch in range(epochs):
    _ = model.train()
    total_loss = 0.0
    
    dataset = MyTextDataset(tokens_embedding_list,labels_list)
    batch_size = 64
    shuffle = True
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    count = 0
    for batch in dataloader:
        inputs, labels = batch['text'], batch['label']
        
        inputs = torch.stack(inputs)
        
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()

        total_loss += loss.item()
        
        count += 1
        print(count,end='\r')
        if count > 100:
            break

    # Print average loss for the epoch
    print(f"Epoch {epoch + 1}, Average Loss: {total_loss / len(dataloader)}")

# Optionally, save your trained model
torch.save(model.state_dict(), 'transformer_model.pth')


Epoch 1, Average Loss: 0.1811889263675036
Epoch 2, Average Loss: 0.18038019652256881
Epoch 3, Average Loss: 0.1800206754823475
Epoch 4, Average Loss: 0.17941849966488227
Epoch 5, Average Loss: 0.17975087223760308
Epoch 6, Average Loss: 0.17949201292394068
Epoch 7, Average Loss: 0.17938252224031923
Epoch 8, Average Loss: 0.17924889533416086
Epoch 9, Average Loss: 0.1793128976126766
Epoch 10, Average Loss: 0.17922661203862456


In [98]:
test_iter = IMDB(split='test')

test_tokens_list = []
test_labels_list = []
for label, line in test_iter:
    token = line.split()
    if len(token) < max_token_len:
        token += ['' for i in range(max_token_len-len(token))]
    else:
        token = token[:max_token_len]
    test_tokens_list.append(token)
    test_labels_list.append(label)
test_labels_list = [i-1 for i in test_labels_list] # now 0 means negative while 1 means positive

test_tokens_embedding_list = []

glove = GloVe(name='6B', dim=50)
for tokens in test_tokens_list:
    embedding_vectors = [glove[token] for token in tokens]
    test_tokens_embedding_list.append(embedding_vectors)

In [105]:
len(test_tokens_embedding_list)

25000

In [127]:
# device = torch.device('mps')
device = torch.device('cpu')

In [128]:
test_dataset = MyTextDataset(test_tokens_embedding_list,test_labels_list)
batch_size = 64
shuffle = True
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)

_ = model.eval()  # Set the model to evaluation mode

total_loss = 0.0
correct_predictions = 0
total_samples = 0

with torch.no_grad():
    count = 0
    for batch in test_dataloader:
        inputs, labels = batch['text'], batch['label']
        
        inputs = torch.stack(inputs)
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(inputs)

        # Compute the loss
        loss = criterion(outputs, labels)
        total_loss += loss.item()

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)
        
        count += 1
        print(count,end='\r')
        if count > 100:
            break

# Print average loss and accuracy for the test set
average_loss = total_loss / len(test_dataloader)
accuracy = correct_predictions / total_samples
print(f"Test Loss: {average_loss}, Accuracy: {accuracy}")

Test Loss: 0.18008186216549496, Accuracy: 0.49860767326732675
