In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
class TextDataset(Dataset):
    def __init__(self, text, seq_len, vocab):
        self.text = text
        self.seq_len = seq_len
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.encoded_text = [vocab[char] for char in text if char in vocab]
    def __len__(self):
        return len(self.encoded_text) - self.seq_len
    def __getitem__(self, idx):
        x = self.encoded_text[idx:idx + self.seq_len]
        y = self.encoded_text[idx + 1:idx + self.seq_len + 1]
        return torch.tensor(x), torch.tensor(y)
    

In [3]:
class TransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ff_hidden_dim, max_seq_len, dropout=0.1):
        super(TransformerLanguageModel, self).__init__()

        # Token and positional embeddings
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)

        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_hidden_dim,
            dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final output projection
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        device = input_ids.device

        # Create position IDs (0 to seq_len-1)
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand_as(input_ids)

        # Token and position embeddings
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)
        embeddings = token_embeddings + position_embeddings

        # Pass through the Transformer encoder
        transformer_output = self.transformer_encoder(embeddings)

        # Project to vocab size for next-token predictions
        logits = self.output_layer(transformer_output)
        return logits

# Preparing the text dataset
text_file = "C:\\Users\\dhuma\\Desktop\\book.txt"
with open(text_file, 'r', encoding='utf-8', errors='ignore') as f:
    text = f.read()

# Building a vocabulary
chars = sorted(list(set(text)))
vocab = {char: idx for idx, char in enumerate(chars)}
vocab_size = len(vocab)

# Hyperparameters
embed_dim = 256
num_heads = 8
num_layers = 3
ff_hidden_dim = 100
max_seq_len = 100
dropout = 0.1
seq_len = max_seq_len
batch_size = 16
num_epochs = 1
learning_rate = 0.001

# Hyperparameters
embed_dim = 256
num_heads = 8
num_layers = 3
ff_hidden_dim = 100
max_seq_len = 100
dropout = 0.1
seq_len = max_seq_len
batch_size = 16
num_epochs = 1
learning_rate = 0.001

# Prepare dataset and dataloader
dataset = TextDataset(text, seq_len, vocab)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the model
model = TransformerLanguageModel(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    ff_hidden_dim=ff_hidden_dim,
    max_seq_len=max_seq_len,
    dropout=dropout
)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch, (x, y) in enumerate(dataloader):
        optimizer.zero_grad()
        logits = model(x)

        # Reshape logits and targets for loss computation
        logits = logits.view(-1, vocab_size)
        y = y.view(-1)

        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if (batch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch+1}/{len(dataloader)}], Loss: {loss.item():.4f}")

    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {total_loss / len(dataloader):.4f}")
         



Epoch [1/1], Step [10/48549], Loss: 3.0485
Epoch [1/1], Step [20/48549], Loss: 3.1149
Epoch [1/1], Step [30/48549], Loss: 2.8286
Epoch [1/1], Step [40/48549], Loss: 2.6576
Epoch [1/1], Step [50/48549], Loss: 2.6248
Epoch [1/1], Step [60/48549], Loss: 2.6381
Epoch [1/1], Step [70/48549], Loss: 2.4880
Epoch [1/1], Step [80/48549], Loss: 2.4899
Epoch [1/1], Step [90/48549], Loss: 2.5834
Epoch [1/1], Step [100/48549], Loss: 2.5388
Epoch [1/1], Step [110/48549], Loss: 2.5618
Epoch [1/1], Step [120/48549], Loss: 2.7388
Epoch [1/1], Step [130/48549], Loss: 2.4991
Epoch [1/1], Step [140/48549], Loss: 2.5322
Epoch [1/1], Step [150/48549], Loss: 2.4920
Epoch [1/1], Step [160/48549], Loss: 2.4629
Epoch [1/1], Step [170/48549], Loss: 2.5954
Epoch [1/1], Step [180/48549], Loss: 2.4765
Epoch [1/1], Step [190/48549], Loss: 2.6072
Epoch [1/1], Step [200/48549], Loss: 2.4847
Epoch [1/1], Step [210/48549], Loss: 2.5383
Epoch [1/1], Step [220/48549], Loss: 2.5423
Epoch [1/1], Step [230/48549], Loss: 2.47

In [11]:
# Text generation function
def generate_text(model, vocab, max_len, start_text, device='cpu'):
    model.eval()
    idx_to_char = {idx: char for char, idx in vocab.items()}
    input_ids = [vocab[char] for char in start_text if char in vocab]

    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)
    generated_sequence = input_ids.tolist()[0]

    with torch.no_grad():
        for _ in range(max_len):
            output = model(input_ids)
            predicted_id = torch.argmax(output[:, 1, :], dim=1).item()
            generated_sequence.append(predicted_id)
            input_ids = torch.tensor([generated_sequence[-len(input_ids[0]):]]).to(device)

    generated_text = ''.join([idx_to_char[i] for i in generated_sequence])
    return generated_text

# Generate text
start_text = "hey how are you its me shivam"
generated_text = generate_text(model, vocab, max_len=100, start_text=start_text)
print(f"Generated Text:\n{generated_text}")


Generated Text:
hey how are you its me shivam  t  htne t  stn tte tt n nett tt    t ttt  t   t  t t    t  tttt t   tt ttt tt t tttt tt    t ttt  
