In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchtext.vocab import GloVe
import json
import math

# Width of the pretrained word vectors; the GloVe table and the embedding
# weight matrix below are both built with this dimension.
embedding_dim= 300
glove_vect= GloVe(name="6B", dim=embedding_dim)

# vocab.json is expected to hold a "word2idx" mapping from token to row index.
with open("data/vocab.json","r",encoding="utf-8") as f:
  vocab= json.load(f)

word2idx= vocab['word2idx']
# NOTE(review): assumes the indices in word2idx are exactly 0..len-1 -- confirm
# against whatever wrote vocab.json.
embeddings= len(word2idx)
pad_idx= word2idx["<pad>"]
weight_matrix= torch.zeros((embeddings, embedding_dim))

for word,i in word2idx.items():
  if i == pad_idx:
    # Padding keeps its all-zero row so it contributes nothing to the input.
    continue
  # torchtext's Vectors lookup never raises KeyError for an unknown token: it
  # returns the `unk_init` vector (zeros by default).  Membership therefore has
  # to be tested explicitly, otherwise every out-of-vocabulary word silently
  # ends up with the same all-zero vector as <pad>.
  if word in glove_vect.stoi:
    weight_matrix[i]= glove_vect[word]
  else:
    weight_matrix[i]= torch.randn(embedding_dim)


# Frozen lookup table: the rows are not updated during training.
embedding_layer= nn.Embedding.from_pretrained(
    weight_matrix,
    freeze=True,
    padding_idx=pad_idx
    )

def positionalencoding(max_len,d_model,device=None):
  """Build the fixed sinusoidal position-encoding table.

  For every position ``pos`` and every even feature index ``i``::

      pe[pos, i]     = sin(pos / 10000 ** (i / d_model))
      pe[pos, i + 1] = cos(pos / 10000 ** (i / d_model))

  When ``d_model`` is odd the last column only receives the sine term.

  Args:
    max_len: number of positions (rows) to generate.
    d_model: number of features per position.
    device: optional torch device to move the result to.

  Returns:
    A float32 tensor of shape ``(1, max_len, d_model)``; the leading axis lets
    the table broadcast over a batch dimension.
  """
  positions= np.arange(max_len, dtype=np.float64)[:, np.newaxis]
  even_dims= np.arange(0, d_model, 2, dtype=np.float64)
  # One angle per (position, sin/cos pair); computed for the whole table at
  # once instead of looping over every cell in Python.
  theta= positions / np.power(10000.0, even_dims / d_model)

  pe= np.zeros((max_len, d_model))
  pe[:, 0::2]= np.sin(theta)
  # There are d_model // 2 odd columns; for an odd d_model that is one fewer
  # than the number of even columns, so the trailing angle has no cosine slot.
  pe[:, 1::2]= np.cos(theta[:, : d_model // 2])

  pe= torch.tensor(pe, dtype=torch.float32).unsqueeze(0)
  if device is not None:
    pe = pe.to(device)
  return pe

class TransformerModel(nn.Module):
  """Transformer-encoder classifier on top of the shared word embeddings.

  Token ids are looked up in the module-level ``embedding_layer``, summed with
  a fixed sinusoidal position table, run through a stack of
  ``nn.TransformerEncoderLayer`` blocks, max-pooled over the non-padding
  positions and projected to ``num_classes`` logits.

  Args:
    num_classes: size of the output logit vector.
    d_model: feature width of the encoder; must match the embedding width.
    num_layers: number of stacked encoder layers.
    nhead: number of attention heads per layer.
    max_len: number of positions in the position table; longer inputs are
      rejected in ``forward``.
    pad_id: token id that marks padding.
    device: optional device on which the position table is created.
    dim_feedforward: hidden width of each encoder layer's feed-forward part.
    dropout: dropout probability used inside the encoder layers and on the
      pooled features before the output projection.

  Raises:
    ValueError: if ``d_model`` differs from the embedding width.
  """

  def __init__(self,num_classes,d_model=300,num_layers=2,nhead=4,max_len=512,pad_id=0,device=None,
               *,dim_feedforward=256,dropout=0.1):
    super().__init__()
    self.pad_id= pad_id
    # The same module-level embedding object is shared by every instance.
    self.embedding= embedding_layer
    if self.embedding.embedding_dim != d_model:
      raise ValueError(
          f"d_model ({d_model}) must equal the embedding width "
          f"({self.embedding.embedding_dim})"
      )
    pos_embedding= positionalencoding(max_len,d_model,device=device)
    # A buffer (not a parameter): it follows .to(device) but is never trained.
    self.register_buffer("pos_embedding",pos_embedding)
    encoder_layer= nn.TransformerEncoderLayer(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        batch_first=True  # tensors are laid out as (batch, seq, feature)
        )
    self.encoder= nn.TransformerEncoder(
        encoder_layer,
        num_layers=num_layers
        )
    self.dropout= nn.Dropout(dropout)
    self.fc= nn.Linear(d_model,num_classes)

  def forward(self,x):
    """Return logits for a batch of token-id sequences.

    Args:
      x: integer tensor of token ids, indexed as (batch, seq).

    Returns:
      Tensor of shape ``(batch, num_classes)``.

    Raises:
      ValueError: if the sequence length exceeds the position table.

    Note:
      Padding positions are filled with ``-inf`` before the max-pool, so a row
      made up entirely of ``pad_id`` pools to ``-inf`` in every feature.
    """
    # True wherever the input is padding; used for attention and for pooling.
    pad_mask= (x==self.pad_id)
    seq_len= x.size(1)
    if seq_len > self.pos_embedding.size(1):
      raise ValueError(
          f"sequence length {seq_len} exceeds max_len "
          f"{self.pos_embedding.size(1)}"
      )
    # Out-of-place add: leaves the tensor returned by the embedding untouched.
    hidden= self.embedding(x) + self.pos_embedding[:,:seq_len,:]
    encoder_output= self.encoder(hidden,src_key_padding_mask=pad_mask)
    # Exclude padding from the max-pool by pushing it to -inf.
    encoder_output= encoder_output.masked_fill(pad_mask.unsqueeze(-1),float('-inf'))
    pooled,_ = encoder_output.max(dim=1)
    # The dropout layer was previously constructed but never applied.
    logits= self.fc(self.dropout(pooled))
    return logits


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# Hyper-parameters for the classifier.  The output layer has one logit per
# vocabulary entry (`embeddings` is the vocabulary size) and the padding id is
# taken from the same vocabulary that built the embedding table.
model_config = {
    "num_classes": embeddings,
    "d_model": 300,
    "num_layers": 2,
    "nhead": 4,
    "max_len": 40,
    "pad_id": word2idx["<pad>"],
    "device": device,
}

model = TransformerModel(**model_config)
model = model.to(device)