<a href="https://colab.research.google.com/github/vivek09thakur/Just-an-Experimental-Response-Network/blob/main/just_an_experimental_response_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **JUST AN EXPERIMENTAL RESPONSE NETWORK**

In [3]:
import torch
import random
from collections import defaultdict, deque
import re

In [4]:
class Vocab:
    def __init__(self):
        self.w2i = {}
        self.i2w = {}
        self.counts = defaultdict(int)
        self.next_i = 0

    def add(self, word):
        word = word.lower()
        if word not in self.w2i:
            self.w2i[word] = self.next_i
            self.i2w[self.next_i] = word
            self.next_i += 1
        self.counts[word] += 1
        return self.w2i[word]

    def get(self, word):
        return self.w2i.get(word.lower(), None)

    def __len__(self):
        return len(self.w2i)

In [5]:
class BaseAgent:
    def __init__(self, lr=0.01, gamma=0.9, eps=0.3):
        self.lr = lr
        self.gamma = gamma
        self.eps = eps
        self.mem = deque(maxlen=1000)

    def remember(self, s, a, r, s_, done):
        self.mem.append((s, a, r, s_, done))

    def learn(self, batch_size):
        if len(self.mem) < batch_size:
            return

        batch = random.sample(self.mem, batch_size)
        for s, a, r, s_, done in batch:
            if done:
                target = r
            else:
                target = r + self.gamma * self._max_q(s_)
            self._update_q(s, a, target)

    def _max_q(self, s):
        raise NotImplementedError

    def _update_q(self, s, a, target):
        raise NotImplementedError

In [6]:
class QAgent(BaseAgent):
    def __init__(self, lr=0.01, gamma=0.9, eps=0.3):
        super().__init__(lr, gamma, eps)
        self.q = None
        self.s_size = 0
        self.a_size = 0

    def init_q(self, s_size, a_size):
        if self.q is None or s_size != self.s_size or a_size != self.a_size:
            self.s_size = s_size
            self.a_size = a_size
            self.q = defaultdict(lambda: torch.zeros(a_size))

    def get_action(self, s):
        if random.random() < self.eps:
            return random.randint(0, self.a_size - 1)
        return torch.argmax(self.q[tuple(s.tolist())]).item()

    def _max_q(self, s):
        return torch.max(self.q[tuple(s.tolist())])

    def _update_q(self, s, a, target):
        s_key = tuple(s.tolist())
        self.q[s_key][a] += self.lr * (target - self.q[s_key][a])


In [7]:
class ChatBot(QAgent):
    def __init__(self, lr=0.01, gamma=0.9, eps=0.3):
        super().__init__(lr, gamma, eps)
        self.vocab = Vocab()
        self.hist = []
        self.last_s = None
        self.last_a = None

    def tokenize(self, text):
        return re.findall(r"\w+|[^\w\s]", text.lower())

    def process(self, text):
        tokens = self.tokenize(text)
        for word in tokens:
            self.vocab.add(word)

        s = torch.zeros(len(self.vocab))
        for word in tokens:
            i = self.vocab.get(word)
            if i is not None:
                s[i] += 1

        hist_v = torch.zeros(len(self.vocab))
        for h_text in self.hist[-3:]:
            h_tokens = self.tokenize(h_text)
            for word in h_tokens:
                i = self.vocab.get(word)
                if i is not None:
                    hist_v[i] += 1

        return torch.cat([s, hist_v])

    def gen_resp(self, a):
        responses = [
            "Hello! How are you?",
            "That's interesting. Tell me more.",
            "I'm still learning. Can you explain that differently?",
            "Thanks for sharing that with me.",
            "I understand. What else would you like to talk about?",
            "Could you clarify that?",
            "Fascinating!",
            "I see. And then what happened?",
            "That's great!",
            "I'm not sure I follow. Can you rephrase?"
        ]

        if a < len(responses):
            return responses[a]

        words = [w for w in self.vocab.w2i.keys() if self.vocab.counts[w] > 3]
        if words:
            sample = random.sample(words, min(3, len(words)))
            return "I know these words: " + ", ".join(sample) + ". Can we talk about them?"

        return "I'm still learning. Please keep talking to me."

    def respond(self, text):
        s = self.process(text)
        self.hist.append(text)

        self.init_q(len(s), 20)

        a = self.get_action(s)
        resp = self.gen_resp(a)
        self.hist.append(resp)

        r = len(self.tokenize(text)) / 10

        if self.last_s is not None:
            self.remember(self.last_s, self.last_a, r, s, False)

        self.learn(32)

        self.last_s = s
        self.last_a = a

        return resp

In [8]:
if __name__ == "__main__":
    bot = ChatBot()

    print("Bot: Hello! I'm a learning bot. Talk to me!")
    print("Type 'quit' to exit.")

    while True:
        user = input("You: ")
        if user.lower() == 'quit':
            break

        resp = bot.respond(user)
        print(f"Bot: {resp}")

        print(f"[Vocab: {len(bot.vocab)}]")
        if bot.mem:
            print(f"[Memory: {len(bot.mem)}]")

Bot: Hello! I'm a learning bot. Talk to me!
Type 'quit' to exit.
You: hello
Bot: Hello! How are you?
[Vocab: 1]
You: hey
Bot: Hello! How are you?
[Vocab: 2]
[Memory: 1]
You: i am fine what about you?
Bot: Hello! How are you?
[Vocab: 9]
[Memory: 2]
You: i said i am fine
Bot: Hello! How are you?
[Vocab: 10]
[Memory: 3]
You: i said that i am fine
Bot: I'm not sure I follow. Can you rephrase?
[Vocab: 11]
[Memory: 4]
You: don't u understand what i said?
Bot: Hello! How are you?
[Vocab: 16]
[Memory: 5]
You: i said that i am fine
Bot: Hello! How are you?
[Vocab: 16]
[Memory: 6]
You: hi, i am fine
Bot: I know these words: fine, said, i. Can we talk about them?
[Vocab: 18]
[Memory: 7]
You: q
Bot: Hello! How are you?
[Vocab: 19]
[Memory: 8]
You: quit
