In [70]:
import torch
import torch.nn as nn

from typing import Tuple
import string

import numpy as np

class custom_GRU_TL(nn.Module):

    # TODO add num_layers parameter
    def __init__(self, block_type, input_dims, hidden_dims, ranks=None, bias_rank=False, freeze_modes=None):
        super().__init__()
        
        self.hidden_dims = hidden_dims

        self.block_type = block_type
        if block_type.lower() == "tcl3d":
            block, args = TCL3D, {}
        elif block_type.lower() == "tcl":
            block, args = TCL, {}
        elif block_type.lower() == "trl":
            block, args = TRL, {"core_shape": ranks}
        elif block_type.lower() == "trl-half":
            block, args = TRLhalf, {"core_shape": ranks}
        elif block_type.lower() == "linear":
            block, args = nn.Linear, {}            
        else:
            raise ValueError(f'Incorrect block type: {block_type}. Should be tcl or trl')

        if not block_type.lower() == "linear":
            self.linear_w = nn.ModuleList([
                block(input_dims, hidden_dims, **args, bias_rank=bias_rank, freeze_modes=freeze_modes) 
                for _ in range(3)])

            self.linear_u = nn.ModuleList([
                block(hidden_dims, hidden_dims, **args, bias_rank=bias_rank, freeze_modes=freeze_modes) 
                for _ in range(3)])
        else:
            self.linear_w = nn.ModuleList([
                block(input_dims, hidden_dims, **args) 
                for _ in range(3)])

            self.linear_u = nn.ModuleList([
                block(hidden_dims, hidden_dims, **args) 
                for _ in range(3)])            
    
    def forward(self, inputs, h_prev=None):

        # inputs shape: L, N, Hin
        L, N, *input_dims = inputs.shape

        outputs = []
        if h_prev is None:
            if self.block_type == "linear":
                h_prev = torch.zeros(N, self.hidden_dims).to(inputs)
            else:
                h_prev = torch.zeros(N, *self.hidden_dims[1:]).to(inputs)
        
        # print(f'h0: {h_prev.shape}')
        
        for x_t in inputs:
            x_z, x_r, x_h = [linear(x_t) for linear in self.linear_w]
            h_z, h_r, h_h = [linear(h_prev) for linear in self.linear_u]
            
            output_z = torch.sigmoid(x_z + h_z)
            output_r = torch.sigmoid(x_r + h_r)
            hidden_hat = torch.tanh(x_h + output_r*h_h)
            h_prev = output_z * h_prev + (1 - output_z) * hidden_hat

            outputs.append(h_prev)#.clone().detach()) # NOTE fail with detach. check it
        
        outputs = torch.stack(outputs, dim=0)
        # print(f'h1: {h_prev.shape}')

        return outputs, h_prev


In [71]:
model = custom_GRU_TL("linear", 64, 2)

In [72]:
x = torch.rand(10, 64, 2)

In [73]:
x.shape

torch.Size([10, 64, 2])

In [74]:
model(x)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x2 and 64x2)