In [1]:
import torch
import torch.onnx
from torch import nn

In [20]:
# Notebook-wide size constants.
# NOTE(review): N_TARGET_FRAMES and N_COLS are not referenced by the cells visible
# here; presumably they describe the (frames, columns) layout of the input — confirm.
N_TARGET_FRAMES, N_COLS = 128, 164
# Used below as the default `max_length` of the Transformer and as the length of
# the dummy decoder input built for the ONNX export.
MAX_PHRASE_LENGTH = 32

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [21]:
device

device(type='cpu')

In [22]:
class SignRecognition(nn.Module):
    """Convolutional front-end that maps a stack of frames to a sequence of
    256-dimensional feature vectors.

    Args:
        frames: Number of input channels of the first convolution (one channel
            per frame). The first convolution outputs ``frames // 2`` channels.
        kernel_size: Kernel size shared by both convolutions.
        stride: Stride shared by both convolutions.

    Shape (with the default ``kernel_size=2, stride=2``):
        input  ``(B, frames, H, W)``
        output ``(B, frames // 4, 256)``

        ``lin1`` expects a last dimension of 20 after the second convolution,
        which holds e.g. for ``H=82, W=2`` (the shape used elsewhere in this
        notebook): ``(B, 128, 82, 2) -> (B, 32, 256)``.
    """

    def __init__(self, frames, kernel_size=2, stride=2):
        super().__init__()
        self.conv1 = nn.Conv2d(frames, frames // 2, kernel_size=kernel_size, stride=stride)
        self.bn1 = nn.BatchNorm2d(frames // 2)
        self.leaky_relu = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride)
        self.bn2 = nn.BatchNorm2d(1)

        self.lin1 = nn.Linear(20, 128)
        self.lin2 = nn.Linear(128, 256)
        # Not used by forward(); kept so the attribute (and module repr) stays available.
        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return a ``(B, T, 256)`` tensor with values in ``[-1, 1]`` (tanh output)."""
        # Example shapes below are for an input of (B, 128, 82, 2).
        x = self.conv1(x)  # (B, 64, 41, 1)
        x = self.leaky_relu(self.bn1(x))

        # Treat the conv1 channels as image rows of a single-channel map and
        # flatten the remaining spatial dims into columns. The sizes are taken
        # from the tensor itself (previously hard-coded to 64 x 41), so any
        # `frames` value passed to the constructor works.
        x = x.view(x.shape[0], 1, x.shape[1], -1)  # (B, 1, 64, 41)

        x = self.conv2(x)  # (B, 1, 32, 20)
        x = self.leaky_relu(self.bn2(x))

        x = self.lin1(x)  # (B, 1, 32, 128)
        x = self.leaky_relu(x)

        x = self.lin2(x)  # (B, 1, 32, 256)
        x = self.tanh(x)

        # Drop the singleton channel dimension -> (B, 32, 256).
        return x.squeeze(1)


In [23]:
class SelfAttention(nn.Module):
    """Multi-head dot-product attention.

    Args:
        embed_size: Size of the last dimension of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size) # concat them

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values: ``(N, value_len, embed_size)``.
            keys: ``(N, key_len, embed_size)``; ``key_len`` must equal ``value_len``.
            query: ``(N, query_len, embed_size)``.
            mask: ``None``, or a tensor broadcastable to
                ``(N, heads, query_len, key_len)``; positions where it equals 0
                are excluded from attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding into `heads` pieces of size `head_dim`.
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        # FIX: reshape the *projected* queries. The previous code reshaped the raw
        # `query` input, silently discarding the output of `self.queries`.
        # WARNING: checkpoints trained with the old behaviour never trained
        # `queries.weight`; verify or retrain such checkpoints before relying on
        # their outputs with this corrected layer.
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # energy: (N, heads, query_len, key_len) - score of every query position
        # against every key position, per head.
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            # A very large negative score becomes ~0 after the softmax.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scores are scaled by sqrt(embed_size) before normalising over the keys.
        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        # value_len == key_len, so a single index `l` is used for both.
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads*self.head_dim,
        ) # flatten the (heads, head_dim) dimensions back into embed_size

        out = self.fc_out(out)
        return out
        

class TransformerBlock(nn.Module):
    """Attention followed by a position-wise feed-forward network, each wrapped
    in a residual connection, LayerNorm and dropout.

    Args:
        embed_size: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        dropout: Dropout probability applied after each LayerNorm.
        forward_expansion: Width multiplier of the hidden feed-forward layer.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Return the block output; the residual path follows ``query``."""
        attended = self.attention(value, key, query, mask)

        # Residual around the attention, then normalise and regularise.
        hidden = self.dropout(self.norm1(attended + query))

        # Residual around the feed-forward network.
        expanded = self.feed_forward(hidden)
        return self.dropout(self.norm2(expanded + hidden))
    
class Encoder(nn.Module):
    """Stack of ``TransformerBlock`` layers applied to pre-computed feature vectors.

    The input is added directly to learned position embeddings, so its last
    dimension must equal ``embed_size``; no token embedding is applied.

    Args:
        scr_vocab_size: Accepted for signature compatibility; not used.
        embed_size: Feature size of the input and of the position embeddings.
        num_layers: Number of ``TransformerBlock`` layers.
        heads: Number of attention heads per layer.
        device: Stored as ``self.device``.
        forward_expansion: Feed-forward width multiplier of each layer.
        dropout: Dropout probability.
        max_length: Number of positions in the position-embedding table; the
            input sequence length must not exceed it.
    """

    def __init__(
            self,
            scr_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
    ):
        super().__init__()
        self.embed_size = embed_size
        self.device = device
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion)
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode ``x``.

        Args:
            x: ``(N, seq_length, embed_size)`` feature tensor.
            mask: Attention mask forwarded to every layer, or ``None``.

        Returns:
            Tensor of shape ``(N, seq_length, embed_size)``.
        """
        N, seq_length, _ = x.shape
        # Build the position indices on the same device as the input. The
        # previous code used the notebook-global `device`, which ignored the
        # device this module/input actually lives on.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)

        out = self.dropout(x + self.position_embedding(positions))
        for layer in self.layers:
            # In the encoder, values, keys and queries are all the same tensor.
            out = layer(out, out, out, mask)

        return out
    
class DecoderBlock(nn.Module):
    """Masked self-attention over the target followed by a ``TransformerBlock``
    that attends over the encoder output.

    Args:
        embed_size: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        forward_expansion: Feed-forward width multiplier of the inner block.
        dropout: Dropout probability.
        device: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size=embed_size,
            heads=heads,
            dropout=dropout,
            forward_expansion=forward_expansion,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Decode one layer; ``value`` and ``key`` come from the encoder."""
        # Self-attention over the target sequence, restricted by `trg_mask`.
        self_attended = self.attention(x, x, x, trg_mask)
        normed = self.norm(self_attended + x)
        query = self.dropout(normed)

        # Attend from the target-side query over the encoder-side value/key.
        return self.transformer_block(value, key, query, src_mask)
    
class Decoder(nn.Module):
    """Token + position embedding, a stack of ``DecoderBlock`` layers, and a
    final linear projection to vocabulary-sized scores.

    Args:
        trg_vocab_size: Number of target tokens (embedding rows and output size).
        embed_size: Embedding / hidden feature size.
        num_layers: Number of ``DecoderBlock`` layers.
        heads: Number of attention heads per layer.
        forward_expansion: Feed-forward width multiplier of each layer.
        dropout: Dropout probability.
        device: Device onto which the position indices are moved.
        max_length: Number of positions in the position-embedding table.
    """

    def __init__(
            self,
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
    ):
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        blocks = [
            DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Return ``(N, seq_length, trg_vocab_size)`` scores for token ids ``x``
        of shape ``(N, seq_length)``."""
        batch_size, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(batch_size, seq_length).to(self.device)

        embedded = self.word_embedding(x) + self.position_embedding(positions)
        hidden = self.dropout(embedded)

        # The encoder output serves as both value and key in every layer.
        for block in self.layers:
            hidden = block(hidden, enc_out, enc_out, src_mask, trg_mask)

        return self.fc_out(hidden)
        
        
class Transformer(nn.Module):
    """Encoder-decoder model: encodes a feature sequence and decodes target
    token scores under a causal (lower-triangular) target mask.

    Args:
        scr_vocab_size: Forwarded to ``Encoder``.
        trg_vocab_size: Size of the target vocabulary (decoder output size).
        src_pad_idx: Stored on the instance; the source mask is currently disabled.
        trg_pad_idx: Stored on the instance.
        embed_size: Hidden feature size of encoder and decoder.
        num_layers: Number of layers in each of encoder and decoder.
        forward_expansion: Feed-forward width multiplier.
        heads: Number of attention heads.
        dropout: Dropout probability.
        device: Device used for position indices and the target mask.
        max_length: Size of the position-embedding tables.
    """

    def __init__(
            self,
            scr_vocab_size,
            trg_vocab_size,
            src_pad_idx,
            trg_pad_idx,
            embed_size=256,
            num_layers=6,
            forward_expansion=4,
            heads=4,
            dropout=0,
            device=device,
            max_length=MAX_PHRASE_LENGTH
    ):
        super().__init__()

        self.encoder = Encoder(
            scr_vocab_size=scr_vocab_size,
            embed_size=embed_size,
            num_layers=num_layers,
            heads=heads,
            device=device,
            forward_expansion=forward_expansion,
            dropout=dropout,
            max_length=max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size=trg_vocab_size,
            embed_size=embed_size,
            num_layers=num_layers,
            heads=heads,
            forward_expansion=forward_expansion,
            dropout=dropout,
            device=device,
            max_length=max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return the source attention mask; always ``None`` (no source masking).

        A padding mask of shape ``(N, 1, 1, src_length)`` built from
        ``src != self.src_pad_idx`` was previously sketched here but is disabled.
        """
        return None

    def make_trg_mask(self, trg):
        """Return a ``(N, 1, trg_length, trg_length)`` lower-triangular mask of
        ones on ``self.device`` for target ids ``trg`` of shape ``(N, trg_length)``."""
        batch_size, trg_length = trg.shape
        lower_triangle = torch.ones((trg_length, trg_length)).tril()
        causal_mask = lower_triangle.expand(batch_size, 1, trg_length, trg_length)
        return causal_mask.to(self.device)

    def forward(self, src, trg):
        """Encode ``src`` and return the decoder scores for ``trg``."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        memory = self.encoder(src, src_mask)
        return self.decoder(trg, memory, src_mask, trg_mask)
    

In [24]:
class Model(nn.Module):
    """End-to-end model: ``SignRecognition`` features fed into a ``Transformer``.

    Args:
        src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx: Forwarded to
            ``Transformer``.
        device: Device the CNN is moved to and that is forwarded to ``Transformer``.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device):
        super().__init__()
        # The CNN front-end is built for 128 input frames.
        frontend = SignRecognition(128)
        self.cnn = frontend.to(device)
        self.transformer = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device)

    def forward(self, x, decoder_input_ids=None):
        """Return the transformer output for frames ``x`` and target ids
        ``decoder_input_ids``.

        Although ``decoder_input_ids`` defaults to ``None``, the transformer
        reads its ``.shape`` to build the target mask, so a tensor must be given.
        """
        features = self.cnn(x)  # sequence of feature vectors from the CNN
        return self.transformer(features, decoder_input_ids)
    

In [25]:
# Source and target share the same vocabulary size and padding index.
src_vocab_size = trg_vocab_size = 62
src_pad_idx = trg_pad_idx = 59

In [26]:
# Build the full model, then move all of its parameters to `device`.
model = Model(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx)
model = model.to(device)
model

Model(
  (cnn): SignRecognition(
    (conv1): Conv2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (leaky_relu): LeakyReLU(negative_slope=0.1)
    (conv2): Conv2d(1, 1, kernel_size=(2, 2), stride=(2, 2))
    (bn2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (lin1): Linear(in_features=20, out_features=128, bias=True)
    (lin2): Linear(in_features=128, out_features=256, bias=True)
    (softmax): Softmax(dim=-1)
    (tanh): Tanh()
  )
  (transformer): Transformer(
    (encoder): Encoder(
      (position_embedding): Embedding(32, 256)
      (layers): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): SelfAttention(
            (values): Linear(in_features=256, out_features=256, bias=False)
            (keys): Linear(in_features=256, out_features=256, bias=False)
            (queries): Linear(in_features=256, out_features=256, bias=False)

In [27]:
# Location of the checkpoint that is loaded into `model` in the next cell.
PATH = "../models/067loss.pth.tar"

In [28]:
# SECURITY: torch.load unpickles the file. `weights_only=True` restricts the
# unpickler to tensors and plain containers, which is all a state dict needs,
# so a tampered checkpoint cannot execute arbitrary code.
# `device` is already a torch.device, so it is passed to map_location directly.
state_dict = torch.load(PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [None]:
# Example inputs used only to trace the model during export. They are created
# on `device` so they live on the same device as the model's parameters
# (CPU tensors would fail against a CUDA model).
dummy_input = torch.randn(1, 128, 82, 2, device=device)
dummy_target = torch.randint(0, trg_vocab_size, size=[1, MAX_PHRASE_LENGTH], device=device)

# Export in inference mode so BatchNorm uses its running statistics.
model.eval()

# Export the model to ONNX format. Graph inputs/outputs get explicit names so
# that later conversion steps can look them up ("input" / "output") instead of
# depending on auto-generated tensor names.
onnx_path = "../models/model.onnx"
torch.onnx.export(
    model,
    (dummy_input, dummy_target),
    onnx_path,
    verbose=True,
    input_names=["input", "decoder_input_ids"],
    output_names=["output"],
)

In [None]:
# Convert the exported ONNX file to TensorFlow with the onnx-tf command-line tool.
# NOTE(review): the next cell parses ../models/model.pb as a single serialized
# GraphDef file — confirm that the installed onnx-tf version writes that format
# (rather than, e.g., a SavedModel directory) at this path.
!onnx-tf convert -i ../models/model.onnx -o ../models/model.pb

In [None]:
import tensorflow as tf

# Names of the graph endpoints handed to the TFLite converter.
# NOTE(review): these must match tensor names that actually exist in the
# converted graph, and the PyTorch model takes two inputs (frames and decoder
# token ids) — confirm whether the second input must be listed here as well.
input_name = "input"
output_name = "output"

# Load the TensorFlow GraphDef
with tf.io.gfile.GFile('../models/model.pb', 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

# Import the graph into a fresh default graph and convert it inside a session
# that is closed again when the block exits.
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
    tf.compat.v1.import_graph_def(graph_def, name='')

    # from_session() needs Tensor objects (it reads their shape/dtype/name),
    # not name strings, so resolve the names to the first output of each op.
    input_tensors = [sess.graph.get_tensor_by_name(f"{input_name}:0")]
    output_tensors = [sess.graph.get_tensor_by_name(f"{output_name}:0")]

    # Convert the TensorFlow model to TensorFlow Lite. The session-based
    # converter lives under tf.compat.v1, like the other TF1-style calls here.
    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
        sess, input_tensors=input_tensors, output_tensors=output_tensors
    )
    tflite_model = converter.convert()

# Save the TensorFlow Lite model
with open('../models/model.tflite', 'wb') as f:
    f.write(tflite_model)
