<a href="https://colab.research.google.com/github/prasannashrestha011/Understanding-tensors/blob/main/chat_bot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import time

# Compute device: the CUDA GPU when PyTorch reports one as available, else the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training data: (user utterance, bot reply) pairs.
# Sentences are lower-case and whitespace-separated because the vocabulary
# tokenises with ``str.lower().split()``.
# NOTE: keep this list free of personal names and derogatory remarks -- the
# model is trained to reproduce these replies verbatim to users.
pairs = [
    # Greetings
    ("hi", "hello there"),
    ("hello", "hi how are you"),
    ("good morning", "good morning to you"),
    ("good night", "good night sleep well"),
    # Small talk
    ("how are you", "i am fine thank you"),
    ("how are you doing", "i am doing well"),
    ("what is your name", "i am a chatbot"),
    ("who are you", "i am an ai assistant"),
    ("how old are you", "i am a computer program"),
    ("where are you from", "i exist in the digital world"),
    ("what do you like", "i like helping people"),
    ("what can you do", "i can chat with you"),
    ("help", "i am here to help you"),
    # Acknowledgements
    ("thanks", "you are welcome"),
    ("thank you", "no problem"),
    ("yes", "okay great"),
    ("no", "i understand"),
    ("maybe", "that sounds reasonable"),
    # Farewells
    ("bye", "goodbye see you later"),
    ("goodbye", "bye take care"),
]

# Build vocab from pairs
class Vocab:
    """Two-way mapping between words and integer indexes.

    Four special tokens are reserved at construction time; every other word
    receives the next free index the first time it is added.

    Attributes:
        word2index: dict mapping each known word to its index.
        index2word: dict mapping each index back to its word.
        n_words: number of entries, special tokens included.
    """

    # Reserved indexes of the special tokens. Naming them once avoids
    # repeating bare 0/1/2/3 literals in every method.
    PAD = 0
    SOS = 1
    EOS = 2
    UNK = 3

    def __init__(self):
        self.word2index = {
            "<pad>": self.PAD,
            "<sos>": self.SOS,
            "<eos>": self.EOS,
            "<unk>": self.UNK,
        }
        # Derive the reverse map so the two tables cannot drift apart.
        self.index2word = {index: word for word, index in self.word2index.items()}
        self.n_words = len(self.word2index)

    def add_sentence(self, sentence):
        """Register every whitespace-separated, lower-cased word of ``sentence``."""
        for word in sentence.lower().split():
            self.add_word(word)

    def add_word(self, word):
        """Assign the next free index to ``word``; known words are left untouched."""
        if word in self.word2index:
            return
        index = self.n_words
        self.word2index[word] = index
        self.index2word[index] = word
        self.n_words += 1

    def sentence_to_indexes(self, sentence):
        """Return the index list for ``sentence`` followed by the EOS index.

        The sentence is lower-cased and split on whitespace; words that were
        never added map to the UNK index.
        """
        indexes = [self.word2index.get(word, self.UNK) for word in sentence.lower().split()]
        indexes.append(self.EOS)
        return indexes

    def indexes_to_sentence(self, indexes):
        """Join the words for ``indexes`` with spaces.

        PAD, SOS and EOS indexes are dropped; indexes missing from the
        vocabulary are rendered as ``"<unk>"``.
        """
        skipped = (self.PAD, self.SOS, self.EOS)
        return ' '.join(
            self.index2word.get(index, "<unk>") for index in indexes if index not in skipped
        )

# Separate vocabularies: one for user utterances, one for bot replies
input_vocab, output_vocab = Vocab(), Vocab()

# The first element of each pair feeds the input vocabulary, the second the output one
for question, answer in pairs:
    input_vocab.add_sentence(question)
    output_vocab.add_sentence(answer)

# Report the sizes; both counts include the four reserved special tokens
print(f"Input vocabulary size: {input_vocab.n_words}")
print(f"Output vocabulary size: {output_vocab.n_words}")

# Encoder with larger hidden size
class EncoderRNN(nn.Module):
    """GRU encoder that consumes an input sentence one token at a time.

    Args:
        input_size: number of rows in the embedding table (vocabulary size).
        hidden_size: width of both the embedding vectors and the GRU state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Advance the encoder by one token.

        Args:
            input: tensor holding a single token index; its embedding is
                reshaped to ``(1, 1, hidden_size)`` before entering the GRU.
            hidden: previous GRU state of shape ``(1, 1, hidden_size)``.

        Returns:
            ``(output, hidden)`` exactly as returned by the GRU for this step.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def init_hidden(self):
        """Return an all-zero initial GRU state of shape ``(1, 1, hidden_size)``.

        The state is created on the device and with the dtype of the module's
        own parameters, so it stays valid after ``.to(...)`` moves the module
        and does not depend on any module-level ``device`` variable.
        """
        weight = self.embedding.weight
        return torch.zeros(
            1, 1, self.hidden_size, device=weight.device, dtype=weight.dtype
        )

# Attention Decoder
class AttnDecoderRNN(nn.Module):
    """Single-step GRU decoder that attends over a fixed number of encoder outputs.

    Args:
        hidden_size: width of the embeddings and of the GRU state.
        output_size: number of output classes (rows of the embedding table
            and width of the final projection).
        max_length: number of attention slots, i.e. the number of encoder
            output rows the decoder can weight.
    """

    def __init__(self, hidden_size, output_size, max_length=15):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.max_length = max_length

        # Layers are created in a fixed order so seeded initialisation is reproducible.
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attn = nn.Linear(2 * hidden_size, max_length)
        self.attn_combine = nn.Linear(2 * hidden_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input, hidden, encoder_outputs):
        """Decode one token.

        Args:
            input: tensor holding the previous token index.
            hidden: previous GRU state, ``(1, 1, hidden_size)``.
            encoder_outputs: ``(max_length, hidden_size)`` matrix of encoder outputs.

        Returns:
            ``(log_probs, hidden, attn_weights)``: log-softmax scores over the
            output classes, the new GRU state, and the softmax attention
            weights over the ``max_length`` slots.
        """
        # Embed the previous token (dropout applies in training mode only)
        token_vec = self.dropout(self.embedding(input).view(1, 1, -1))
        query = token_vec[0]

        # Score every encoder slot from the token embedding and the current state
        scores = self.attn(torch.cat((query, hidden[0]), 1))
        attn_weights = F.softmax(scores, dim=1)

        # Attention-weighted sum of the encoder outputs
        context = torch.bmm(attn_weights.unsqueeze(0),
                            encoder_outputs.unsqueeze(0))

        # Fuse embedding and context, then run one GRU step
        fused = self.attn_combine(torch.cat((query, context[0]), 1))
        gru_input = F.relu(fused.unsqueeze(0))
        gru_output, hidden = self.gru(gru_input, hidden)

        log_probs = F.log_softmax(self.out(gru_output[0]), dim=1)
        return log_probs, hidden, attn_weights

# --- Hyperparameters ---
hidden_size = 64             # width of the embeddings and GRU states
max_length = 15              # number of encoder positions the decoder can attend over
teacher_forcing_ratio = 0.5  # chance that train() feeds the ground-truth token back in
learning_rate = 0.001        # kept low for better stability
criterion = nn.NLLLoss()     # the decoder emits log-probabilities

# --- Models (encoder is built first, then the decoder) ---
encoder = EncoderRNN(input_vocab.n_words, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, output_vocab.n_words, max_length).to(device)

# --- One Adam optimizer per network ---
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

def tensors_from_sentence(vocab, sentence):
    """Encode ``sentence`` with ``vocab`` and return the result as a tensor.

    The indexes come from ``vocab.sentence_to_indexes(sentence)``; the tensor
    has dtype ``torch.long`` and is placed on the module-level ``device``.
    """
    return torch.tensor(
        vocab.sentence_to_indexes(sentence), dtype=torch.long, device=device
    )

def train(input_tensor, target_tensor):
    """Run one optimisation step on a single (input, target) sentence pair.

    Uses the module-level ``encoder``, ``decoder``, their optimizers,
    ``criterion``, ``max_length``, ``teacher_forcing_ratio`` and ``device``.

    Args:
        input_tensor: 1-D tensor of input token indexes.
        target_tensor: 1-D tensor of target token indexes.

    Returns:
        The summed loss divided by the number of decoder steps that were
        actually scored, as a Python float. ``0.0`` when the target is empty.
    """
    # Training mode keeps dropout active even if the modules were switched
    # to eval mode earlier.
    encoder.train()
    decoder.train()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # Encode: keep at most ``max_length`` per-token outputs for attention;
    # unused rows stay zero.
    encoder_hidden = encoder.init_hidden()
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        if ei < max_length:
            encoder_outputs[ei] = encoder_output[0, 0]

    # Decode, starting from SOS and the final encoder state
    decoder_input = torch.tensor([1], device=device)  # SOS
    decoder_hidden = encoder_hidden
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    loss = 0
    steps = 0  # decoder steps that contributed to ``loss``
    for di in range(target_length):
        decoder_output, decoder_hidden, _ = decoder(decoder_input, decoder_hidden, encoder_outputs)
        loss += criterion(decoder_output, target_tensor[di].unsqueeze(0))
        steps += 1

        if use_teacher_forcing:
            # Feed the ground-truth token as the next input
            decoder_input = target_tensor[di]
        else:
            # Feed the model's own best guess; detach so gradients do not
            # flow through the sampled token
            _, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            if decoder_input.item() == 2:  # EOS
                break

    if steps == 0:
        # Empty target: ``loss`` is still the int 0 and has no ``backward()``
        return 0.0

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    # Average over the scored steps. Dividing by ``target_length`` would
    # under-report the loss whenever decoding stopped early on EOS.
    return loss.item() / steps

def evaluate(sentence):
    """Generate the bot's reply to ``sentence`` by greedy decoding.

    Uses the module-level ``encoder``, ``decoder``, ``input_vocab``,
    ``output_vocab``, ``max_length`` and ``device``.

    Args:
        sentence: raw user text; tokenised by ``input_vocab``.

    Returns:
        The decoded words joined by spaces, or ``"i don't understand"`` when
        no word was produced.
    """
    # Inference must run in eval mode, otherwise the decoder's dropout layer
    # stays active and the same input can yield different replies. Remember
    # the previous modes so a later training step is unaffected.
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    decoded_words = []
    try:
        with torch.no_grad():
            input_tensor = tensors_from_sentence(input_vocab, sentence)
            input_length = input_tensor.size(0)
            encoder_hidden = encoder.init_hidden()
            encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

            # Encode: tokens beyond ``max_length`` still update the hidden
            # state but are not stored for attention
            for ei in range(input_length):
                encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
                if ei < max_length:
                    encoder_outputs[ei] = encoder_output[0, 0]

            # Decode greedily for at most ``max_length`` steps
            decoder_input = torch.tensor([1], device=device)  # SOS
            decoder_hidden = encoder_hidden

            for _ in range(max_length):
                decoder_output, decoder_hidden, _ = decoder(decoder_input, decoder_hidden, encoder_outputs)
                _, topi = decoder_output.topk(1)
                token = topi.item()
                if token == 2:  # EOS
                    break
                if token != 0:  # Skip PAD tokens
                    decoded_words.append(output_vocab.index2word[token])
                decoder_input = topi.squeeze().detach()
    finally:
        # Restore whatever mode the modules were in before this call
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)

    return ' '.join(decoded_words) if decoded_words else "i don't understand"

# Training
n_epochs = 5000    # number of single-pair optimisation steps
print_every = 500  # reporting interval, in steps

# The sentences never change, so convert every pair to tensors once instead
# of rebuilding them on each of the n_epochs iterations.
training_tensors = [
    (tensors_from_sentence(input_vocab, question),
     tensors_from_sentence(output_vocab, answer))
    for question, answer in pairs
]

print("\nStarting training...")
running_loss = 0.0  # sum of per-step losses since the last report
for epoch in range(1, n_epochs + 1):
    input_tensor, target_tensor = random.choice(training_tensors)
    loss = train(input_tensor, target_tensor)
    running_loss += loss

    if epoch % print_every == 0:
        # Report the mean over the window: the loss of one randomly drawn
        # pair is too noisy to show whether training is converging.
        print(f"Epoch {epoch} Loss {running_loss / print_every:.4f}")
        running_loss = 0.0

print("\nChatbot ready! Type 'quit' to stop.")
print("Try: hi, how are you, what is your name, bye")

# Interactive loop: read a line, answer it, repeat until the user quits
while True:
    try:
        user_input = input("\nYou: ").strip()
    except (EOFError, KeyboardInterrupt):
        # stdin was closed or the user pressed Ctrl-C: leave the chat
        # cleanly instead of crashing with a traceback
        print()
        break

    # Surrounding whitespace was stripped above, so " quit " also exits
    if user_input.lower() == 'quit':
        break
    if not user_input:
        # Nothing typed; prompt again rather than querying the model
        continue

    try:
        output = evaluate(user_input)
    except Exception as e:
        # Top-level boundary: report the problem and keep the chat alive
        print(f"Bot: Sorry, I had an error: {e}")
    else:
        print("Bot:", output)