In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
import json
import pandas as pd

In [2]:
# Preferred accelerator name. The model classes below only use it when
# torch.cuda.is_available() is true; otherwise they fall back to "cpu".
deviceName = "cuda"

In [17]:
class tokenizer:
    """Small byte-pair-encoding tokenizer driven by two files in ``path``.

    * ``vocab.txt`` - a JSON object mapping token string -> integer id.
    * ``merge.txt`` - one whitespace-separated token pair per line; the first
      line is skipped and the remaining line order defines the merge rank
      (earlier line = merged first).

    Spaces are represented by "Ġ" and newlines by "Ċ" inside the vocabulary.
    """

    def __init__(self, path = r"C:\Users\shiva\Desktop\IISC\code\NeuroCpp\Projects\The Dream\embedding"):
        """Load the vocabulary and merge table from the directory ``path``."""
        self.path = path
        # Read as UTF-8 explicitly (like merge.txt below): the platform default
        # codec on Windows would corrupt non-ASCII tokens such as "Ġ".
        with open(os.path.join(path, "vocab.txt"), "r", encoding="utf-8") as f:
            self.vocab = json.load(f)

        self.merge = dict()
        with open(os.path.join(path, "merge.txt"), "r", encoding="utf-8") as f:
            # The first line is skipped; rank is the position among the rest.
            for index, words in enumerate(f.readlines()[1:]):
                self.merge[tuple(words.strip().split())] = index

        # Inverse mapping (id -> token string) used by decode().
        self.reverseVocab = {tokenId: token for token, tokenId in self.vocab.items()}

    def GetSplitWord(self, txt):
        """Split ``txt`` into BPE tokens.

        Starts from single characters and repeatedly merges the adjacent pair
        with the lowest merge rank (leftmost occurrence on ties) until no
        adjacent pair is present in the merge table.

        Returns a list of token strings; an empty input yields an empty list.
        """
        txt = list(txt)
        while True:
            changeIndex = -1
            rank = None
            for index in range(1, len(txt)):
                pairRank = self.merge.get((txt[index - 1], txt[index]))
                # Strict '<' keeps the leftmost pair when ranks are equal.
                if pairRank is not None and (rank is None or pairRank < rank):
                    changeIndex = index
                    rank = pairRank
            if changeIndex == -1:
                break
            # Fuse the winning pair into its left slot and drop the right one.
            txt[changeIndex - 1] += txt[changeIndex]
            txt.pop(changeIndex)
        return txt

    def encode(self, txt):
        """Convert ``txt`` to a list of token ids.

        Spaces and newlines are first replaced by their "Ġ" / "Ċ" markers.
        Tokens that are not in the vocabulary are silently dropped.
        """
        txt = txt.replace(" ", "Ġ").replace("\n", "Ċ")
        return [self.vocab[word] for word in self.GetSplitWord(txt) if word in self.vocab]

    def decode(self, li):
        """Convert an iterable of token ids back to text.

        Raises KeyError if an id is not present in the vocabulary.
        """
        txt = "".join(self.reverseVocab[i] for i in li)
        return txt.replace("Ġ", " ").replace("Ċ", "\n")

In [18]:
# Directory that TagClassify reads 'transformer.wte.weight.npy' from.
# NOTE(review): the folder name is spelled "weigths" - confirm this matches the directory on disk.
weightPath = r"C:\Users\shiva\Desktop\IISC\code\NeuroCpp\Projects\The Dream\weigths"
class TextEmbedding(nn.Module):
    """Encode a sequence of embedding vectors into one fixed-size vector.

    A 3-layer LSTM reads the sequence; the final hidden state of its last
    layer is passed through a two-layer MLP (Linear -> ReLU -> Linear).

    Args:
        inputSize: feature size of each input time step.
        hiddenSize: width of both MLP layers, and the size of the output vector.
        lstmHiddenSize: hidden-state size of the LSTM.
    """

    def __init__(self, inputSize = 768, hiddenSize = 2048, lstmHiddenSize = 768):
        super().__init__()
        # Kept for backward compatibility; note that the MLP actually consumes
        # the LSTM hidden state, i.e. lstmHiddenSize features.
        self.mlpInputDim = inputSize
        self.mlpHiddenDim = hiddenSize
        self.lstmHiddenDim = lstmHiddenSize
        self.device = torch.device(deviceName if torch.cuda.is_available() else "cpu")
        self.lstMBlock = nn.LSTM(
            input_size=inputSize,
            hidden_size=lstmHiddenSize,
            num_layers=3,
            batch_first=True,
            device=self.device,
        )
        # linear1 is fed the LSTM hidden state, so it must be sized by
        # lstmHiddenSize (sizing it by inputSize only worked when both matched).
        self.linear1 = nn.Linear(in_features=lstmHiddenSize, out_features=hiddenSize, device=self.device)
        self.linear2 = nn.Linear(in_features=hiddenSize, out_features=hiddenSize, device=self.device)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the MLP projection of the last LSTM layer's final hidden state.

        Args:
            x: input sequence; with batch_first=True the LSTM expects
               (batch, seq, inputSize).

        Returns:
            Tensor of shape (batch, hiddenSize).
        """
        _, (hn, _) = self.lstMBlock(x)
        # hn stacks one final state per LSTM layer; only the top layer is used,
        # so select it before the MLP instead of projecting all layers.
        lastLayerState = hn[-1]
        return self.linear2(self.relu(self.linear1(lastLayerState)))

class TagClassify(nn.Module):
    """Multi-label tag classifier over token-id sequences.

    Pipeline: frozen pretrained token embedding -> TextEmbedding (LSTM + MLP)
    -> linear layer -> sigmoid, giving one independent score per tag.

    Args:
        numberOfTag: number of output tags.
        inputSize: feature size fed to the LSTM; must equal the width of the
            pretrained embedding matrix loaded from ``weightPath``.
        hiddenSize: width of the TextEmbedding MLP / classifier input.
        lstmHiddenSize: hidden-state size of the LSTM.
    """

    def __init__(self, numberOfTag = 176, inputSize = 768, hiddenSize = 2048, lstmHiddenSize = 768):
        super().__init__()
        self.device = torch.device(deviceName if torch.cuda.is_available() else "cpu")
        self.textEmbd = TextEmbedding(inputSize, hiddenSize, lstmHiddenSize)
        self.linear = nn.Linear(in_features=hiddenSize, out_features=numberOfTag, device=self.device)
        # Sigmoid has no parameters or buffers, so it needs no device placement.
        self.sigmoid = nn.Sigmoid()
        # Build the embedding directly from the saved matrix instead of
        # allocating a random table and overwriting it; freeze=True keeps the
        # weights out of training (requires_grad=False).
        pretrained = torch.from_numpy(
            np.load(os.path.join(weightPath, 'transformer.wte.weight.npy'))
        ).to(dtype=torch.float32, device=self.device)
        self.embd = nn.Embedding.from_pretrained(pretrained, freeze=True)

    def forward(self, x):
        """Score every tag for each sequence in the batch.

        Args:
            x: integer token ids, shape (batch, seq), on ``self.device``.

        Returns:
            Tensor of shape (batch, numberOfTag) with values in (0, 1).
        """
        return self.sigmoid(self.linear(self.textEmbd(self.embd(x))))

In [54]:
class DataSet:
    """Batched access to a CSV of poems and their comma-separated tags.

    The CSV must provide a ``Poem`` column (text) and a ``Tags`` column
    (tags joined by ","). Each distinct tag string is assigned a column index
    in ``tagIndex`` in order of first appearance.

    Args:
        batchSize: number of rows in every returned batch.
        path: CSV file to read; defaults to the original hard-coded location.
    """

    # Id used to right-pad token sequences (last id of a 50257-entry vocabulary).
    PAD_TOKEN_ID = 50256
    # Poems with more whitespace-separated words than this are skipped.
    MAX_WORDS = 300

    def __init__(self, batchSize = 64, path = None):
        if path is None:
            path = r"C:\Users\shiva\Desktop\IISC\code\NeuroCpp\Projects\The Dream\DataSet\multiTagedData.csv"
        self.path = path
        data = pd.read_csv(self.path)
        self.poems = list(data.Poem)
        self.tags = list(data.Tags)
        self.batchSize = batchSize
        self.tagIndex = dict()
        self.encdr = tokenizer()
        self.numberOfTag = self.CreateTag()

    def Next(self, index):
        """Build the batch covering poems ``index`` .. ``index + batchSize - 1``.

        Poems longer than MAX_WORDS words are skipped. Kept poems fill the
        batch rows from the top in order, so row r of both returned tensors
        always refers to the same poem; unused trailing rows stay all-padding
        with all-zero tags.

        Returns:
            seqPoem: long tensor (batchSize, longest kept sequence) of token
                ids, right-padded with PAD_TOKEN_ID.
            batchTag: float tensor (batchSize, numberOfTag), multi-hot tags.
        """
        encodedPoems = []
        batchTag = torch.zeros((self.batchSize, self.numberOfTag))
        stop = min(index + self.batchSize, len(self.poems))
        for i in range(index, stop):
            if len(self.poems[i].split()) > self.MAX_WORDS:
                continue
            # Row within the batch, NOT the absolute poem index: keeps tags
            # aligned with seqPoem even when earlier poems were skipped.
            row = len(encodedPoems)
            for tag in self.tags[i].split(","):
                batchTag[row][self.tagIndex[tag]] = 1
            encodedPoems.append(self.encdr.encode(self.poems[i]))

        longest = max((len(encd) for encd in encodedPoems), default=0)
        seqPoem = torch.full((self.batchSize, longest), self.PAD_TOKEN_ID, dtype=torch.long)
        for row, encd in enumerate(encodedPoems):
            seqPoem[row, :len(encd)] = torch.tensor(encd, dtype=torch.long)
        return seqPoem, batchTag

    def CreateTag(self):
        """Populate ``tagIndex`` from ``self.tags`` and return the tag count.

        New tags get the next free index, so calling this again never reuses
        an index already assigned.
        """
        for tag in self.tags:
            for eachTag in tag.split(","):
                if eachTag not in self.tagIndex:
                    self.tagIndex[eachTag] = len(self.tagIndex)
        return len(self.tagIndex)

In [72]:
# Instantiate the classifier and the tokenizer with their default arguments;
# both constructors read files from the hard-coded paths defined above.
model = TagClassify()
tkn = tokenizer()

In [73]:
# Smoke test: tokenize one sentence, move it to the model's device, add a
# leading batch dimension, and display the shape of the model output.
a = (
    torch.tensor(tkn.encode("why are we getting this error"))
    .to(model.device)
    .unsqueeze(0)
)
model(a).shape

torch.Size([1, 176])