In [18]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import math
import numpy as np

class SimpleTokenizer:
    """
    A simple character-level tokenizer.
    """
    def __init__(self, corpus):
        self.vocab = sorted(list(set(corpus)))
        self.vocab_size = len(self.vocab)
        self.char_to_idx = {ch: i for i, ch in enumerate(self.vocab)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.vocab)}

        self.padding_idx = self.vocab_size
        self.vocab_size += 1


    def encode(self, text):
        """Converts a string of text into a list of integers."""
        return [self.char_to_idx[ch] for ch in text]

    def decode(self, indices):
        """Converts a list of integers back into a string of text."""

        return "".join([self.idx_to_char[i] for i in indices if i != self.padding_idx])


class SelfAttention(nn.Module):
    """
    A simple self-attention mechanism.
    """
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]


        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    """
    A single block of the Transformer architecture.
    """
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out



class Transformer(nn.Module):
    """
    The full Transformer architecture.
    """
    def __init__(
        self,
        vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
        padding_idx
    ):
        super(Transformer, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx)
        self.max_length = max_length

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.fc_out = nn.Linear(embed_size, vocab_size - 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        word_embeddings = self.word_embedding(x)


        position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1).to(self.device)
        div_term = torch.exp(torch.arange(0, self.embed_size, 2).float() * (-math.log(10000.0) / self.embed_size)).to(self.device)

        position_encodings = torch.zeros(seq_length, self.embed_size).to(self.device)
        position_encodings[:, 0::2] = torch.sin(position * div_term)
        position_encodings[:, 1::2] = torch.cos(position * div_term)
        position_encodings = position_encodings.unsqueeze(0).expand(N, seq_length, self.embed_size)

        out = self.dropout(word_embeddings + position_encodings)

        for layer in self.layers:
            out = layer(out, out, out, mask)


        out = self.fc_out(out)
        return out



class MLMDataset(Dataset):
    """
    A dataset for masked language modeling.
    """
    def __init__(self, corpus, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.input_sequences = []
        self.target_sequences = []

        encoded_corpus = self.tokenizer.encode(corpus)


        for i in range(len(encoded_corpus) - max_length):

            chunk = encoded_corpus[i : i + max_length + 1]


            input_seq = torch.tensor(chunk[:-1], dtype=torch.long)
            target_seq = torch.tensor(chunk[1:], dtype=torch.long)

            self.input_sequences.append(input_seq)
            self.target_sequences.append(target_seq)


    def __len__(self):
        return len(self.input_sequences)

    def __getitem__(self, idx):
        return self.input_sequences[idx], self.target_sequences[idx]

def pretrain(model, dataloader, optimizer, criterion, epochs):
    """
    The pre-training loop.
    """
    model.train()
    for epoch in range(epochs):
        for input_seq, target_seq in dataloader:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)


            mask = torch.tril(torch.ones(input_seq.shape[1], input_seq.shape[1])).to(device)


            output = model(input_seq, mask)


            loss = criterion(output.view(-1, output.shape[-1]), target_seq.view(-1))


            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Pre-training Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")



class SentimentDataset(Dataset):
    """
    A dataset for sentiment analysis.
    """
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, label = self.data[idx]
        encoded_text = self.tokenizer.encode(text)


        if len(encoded_text) < self.max_length:
            padded_text = encoded_text + [self.tokenizer.padding_idx] * (self.max_length - len(encoded_text))
        else:
            padded_text = encoded_text[:self.max_length]

        return torch.tensor(padded_text, dtype=torch.long), torch.tensor(label, dtype=torch.long)


def finetune(model, dataloader, optimizer, criterion, epochs):
    """
    The fine-tuning loop.
    """
    model.train()
    for epoch in range(epochs):
        for text, label in dataloader:
            text, label = text.to(device), label.to(device)


            output = model(text, mask=None)


            prediction = output[:, -1, :]


            loss = criterion(prediction, label)


            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Fine-tuning Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")



if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    print("--- Starting Pre-training ---")


    pretrain_corpus = "This is a simple corpus for pre-training our model. It contains some text to learn from."


    finetune_data = [
        ("this is a great movie", 1),
        ("i hated this film", 0),
        ("the acting was terrible", 0),
        ("a true masterpiece", 1),
    ]

    inference_text = "this movie was amazing"



    combined_corpus = pretrain_corpus + "".join([text for text, label in finetune_data]) + inference_text



    tokenizer = SimpleTokenizer(combined_corpus)



    pretrain_dataset = MLMDataset(pretrain_corpus, tokenizer, max_length=20)
    pretrain_dataloader = DataLoader(pretrain_dataset, batch_size=2, shuffle=True)


    model = Transformer(
        vocab_size=tokenizer.vocab_size,
        embed_size=256,
        num_layers=2,
        heads=4,
        device=device,
        forward_expansion=4,
        dropout=0.1,
        max_length=20,
        padding_idx=tokenizer.padding_idx
    ).to(device)


    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.padding_idx)


    pretrain(model, pretrain_dataloader, optimizer, criterion, epochs=10)

    print("\n--- Pre-training Finished ---")


    print("\n--- Starting Fine-tuning ---")



    finetune_dataset = SentimentDataset(finetune_data, tokenizer, max_length=20)
    finetune_dataloader = DataLoader(finetune_dataset, batch_size=2, shuffle=True)


    model.fc_out = nn.Linear(model.embed_size, 2).to(device)


    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.CrossEntropyLoss()


    finetune(model, finetune_dataloader, optimizer, criterion, epochs=10)

    print("\n--- Fine-tuning Finished ---")


    print("\n--- Running Inference ---")

    model.eval()
    test_text = "this movie was amazing"
    encoded_text = tokenizer.encode(test_text)

    padded_text = encoded_text + [tokenizer.padding_idx] * (20 - len(encoded_text))
    input_tensor = torch.tensor([padded_text], dtype=torch.long).to(device)

    with torch.no_grad():
        output = model(input_tensor, mask=None)
        prediction = torch.argmax(output[:, -1, :], dim=1).item()

    sentiment = "Positive" if prediction == 1 else "Negative"
    print(f"Text: '{test_text}'")
    print(f"Predicted sentiment: {sentiment}")

--- Starting Pre-training ---
Pre-training Epoch 1/10, Loss: 1.2131
Pre-training Epoch 2/10, Loss: 0.7460
Pre-training Epoch 3/10, Loss: 0.6884
Pre-training Epoch 4/10, Loss: 0.3699
Pre-training Epoch 5/10, Loss: 0.3553
Pre-training Epoch 6/10, Loss: 0.3578
Pre-training Epoch 7/10, Loss: 0.2421
Pre-training Epoch 8/10, Loss: 0.3084
Pre-training Epoch 9/10, Loss: 0.2085
Pre-training Epoch 10/10, Loss: 0.3429

--- Pre-training Finished ---

--- Starting Fine-tuning ---
Fine-tuning Epoch 1/10, Loss: 0.8570
Fine-tuning Epoch 2/10, Loss: 0.8053
Fine-tuning Epoch 3/10, Loss: 0.5933
Fine-tuning Epoch 4/10, Loss: 0.5676
Fine-tuning Epoch 5/10, Loss: 0.3893
Fine-tuning Epoch 6/10, Loss: 0.4958
Fine-tuning Epoch 7/10, Loss: 0.4457
Fine-tuning Epoch 8/10, Loss: 0.5735
Fine-tuning Epoch 9/10, Loss: 0.5418
Fine-tuning Epoch 10/10, Loss: 0.5020

--- Fine-tuning Finished ---

--- Running Inference ---
Text: 'this movie was amazing'
Predicted sentiment: Positive
