# Diagonal

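Prototype for writing einsum results onto generalized diagonals. These are outputs whose subscripts repeat a label, e.g. `ijk,jk->ijiji`. The approach is to contract into an `as_strided` view of a zero-filled array. Related NumPy issue: https://github.com/numpy/numpy/issues/4965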

In [1]:
# Target: `np.einsum('ijk,jk->ijiji', a, b)` -- the output repeats i (3x) and j (2x).
# Strategy: compute the reduced 'ijk,jk->ij' contraction straight into a strided view
# that walks the generalized diagonal of a zero-filled (i, j, i, j, i) array.

import numpy as np
from numpy.lib.stride_tricks import as_strided

rng = np.random.default_rng(0)  # seeded so Restart & Run All reproduces the same data
i, j, k = 10, 20, 30
a = rng.random((i, j, k))
b = rng.random((j, k))
out = np.zeros((i, j, i, j, i))
# One step along view axis 0 advances all three `i` axes of `out` at once;
# one step along view axis 1 advances both `j` axes.
out_view = as_strided(out, shape=(i, j),
                      strides=(out.strides[0] + out.strides[2] + out.strides[4],
                               out.strides[1] + out.strides[3]))
np.einsum('ijk,jk->ij', a, b, out=out_view);

In [2]:
def arr(*shape):
    """Return an array of the given shape holding 0, 1, 2, ... in C order."""
    n_elements = np.prod(shape)
    flat = np.arange(n_elements)
    return flat.reshape(shape)

In [16]:
from itertools import filterfalse
def unique_from_end(in_str):
    """ Return a string with all redundant characters removed,
        keeping the right-most occurrence of each character
        (i.e. removing the left-most redundant entries).

        e.g. "ijikik" -> "jik"

        Parameters
        ----------
        in_str: str

        Returns
        -------
        str"""
    seen = set()
    kept = []
    # Walk right-to-left so the right-most occurrence of each character is the one kept.
    for ch in reversed(in_str):
        if ch not in seen:
            seen.add(ch)
            kept.append(ch)
    return "".join(reversed(kept))

def test_unique_from_end():
    """Check `unique_from_end` on edge cases and on its docstring example."""
    cases = [
        ("", ""),
        ("a", "a"),
        ("aaaa", "a"),
        ("aba", "ba"),
        ("abccbac", "bac"),
        ("ijikik", "jik"),  # the example given in the docstring
        ("ijkji", "kji"),   # the labels used for x later in this notebook
    ]
    # Run the table twice: results must not depend on state left by earlier calls.
    for _ in range(2):
        for in_str, expected in cases:
            got = unique_from_end(in_str)
            assert got == expected, "unique_from_end(%r) == %r, expected %r" % (in_str, got, expected)

def parse_labels_for_repeats(lbls):
    """Split an einsum label string into its distinct labels and its repeated labels.

    Example: ``parse_labels_for_repeats('ijkji') -> ('kji', {'i', 'j'})``

    Returns a tuple ``(unique_labels, repeated)``. ``unique_labels`` is
    ``unique_from_end(lbls)``. ``repeated`` is the set of labels occurring more
    than once in ``lbls``.
    """
    unique_labels = unique_from_end(lbls)
    return unique_labels, {lbl for lbl in unique_labels if lbls.count(lbl) > 1}

# Run the checks right away so a regression in unique_from_end fails this cell.
test_unique_from_end()

In [4]:
def lbl_to_size_mapping(lbls, shape):
    """Map each label to the largest axis length it is paired with.

    Parameters
    ----------
    lbls: str
        One label per axis, e.g. "ijkji"; labels may repeat.
    shape: tuple of int
        Axis lengths, paired positionally with ``lbls`` (extra entries on
        either side are ignored, as with ``zip``).

    Returns
    -------
    dict
        ``{label: largest length}``. A label whose axes all have length 0
        is still present, mapped to 0.
    """
    mapping = {}
    for lbl, size in zip(lbls, shape):
        # Membership test rather than ``get(lbl, 0)`` so zero-length axes are recorded too.
        if lbl not in mapping or mapping[lbl] < size:
            mapping[lbl] = size
    return mapping

def merge_max_mappings(*mappings):
    """Merge label -> size dicts, keeping the largest size seen for each label.

    Parameters
    ----------
    *mappings: dict
        At least one mapping of label -> size.

    Returns
    -------
    dict
        A new dict; none of the inputs is modified. Labels that map only to 0 are kept.

    Raises
    ------
    AssertionError
        If called with no mappings.
    """
    assert len(mappings) > 0
    mapping = dict(mappings[0])  # copy so the caller's first dict is left untouched
    for mapp in mappings[1:]:
        for key, val in mapp.items():
            if key not in mapping or mapping[key] < val:
                mapping[key] = val
    return mapping

In [68]:
# Worked example: x carries repeated labels (i and j twice each); `grad` has the
# shape of the einsum output, whose labels are the part of `script` after '->'.
script = 'ijkji,ik->jk'
i = 2
j = 3
k = 4
x = arr(i, j, k, j, i)
y = arr(i, k)
grad = arr(j, k)
# Derive every label string from `script` so they cannot drift out of sync with it.
input_spec, grad_lbls = script.split('->')
xlbls, ylbls = input_spec.split(',')
xshape = x.shape
yshape = y.shape
assert (len(xlbls), len(ylbls), len(grad_lbls)) == (x.ndim, y.ndim, grad.ndim)

var_lbl, repeat_lbls = parse_labels_for_repeats(xlbls)

In [69]:
# Forward result of the contraction described by `script`, for reference.
np.einsum(script, x, y)

array([[ 292,  401,  534,  691],
       [ 396,  557,  742,  951],
       [ 500,  713,  950, 1211]])

In [70]:
# Size of every label across both operands. `lbl_to_size_mapping` keeps the largest
# size when a label repeats inside one operand (e.g. the two `i` axes of x), whereas
# a plain dict comprehension would silently keep whichever size came last.
size_mappings = [lbl_to_size_mapping(lbls, operand.shape)
                 for lbls, operand in zip([xlbls, ylbls], [x, y])]
lbl_to_size = merge_max_mappings(*size_mappings)

In [71]:
# Zero-filled array with one axis per label of x, repeats included.
out = np.zeros([lbl_to_size[lbl] for lbl in xlbls])

In [72]:
# One axis per distinct label of x, in `var_lbl` order.
out_view_shape = tuple([lbl_to_size[lbl] for lbl in var_lbl])

In [73]:
def get_indices(item, seq):
    """Lazily yield, in increasing order, every position in ``seq`` whose element equals ``item``."""
    matches = (pos for pos, elem in enumerate(seq) if elem == item)
    return matches

In [74]:
# For each distinct label, add up the byte strides of every axis of `out` that carries it.
strides = tuple(
    sum(out.strides[axis] for axis in get_indices(lbl, xlbls))
    for lbl in var_lbl
)

In [75]:
# Byte strides of the diagonal view, one entry per label in var_lbl.
strides

(48, 208, 584)

In [76]:
# View of `out` with one axis per distinct label of x. Each view stride is the sum of
# the strides of all axes sharing that label, so writes land on the generalized diagonal.
out_view = as_strided(out, shape=out_view_shape, strides=strides)

In [77]:
# Distinct labels of x, in the order used for the axes of out_view.
var_lbl

'kji'

In [78]:
# Labels of the einsum output, which are also the labels of grad.
grad_lbls

'jk'

In [79]:
# grad has one axis per label in grad_lbls.
grad.shape

(3, 4)

In [80]:
# einsum spec contracting grad with y into one axis per distinct label of x.
",".join([grad_lbls, ylbls]) + "->" + var_lbl

'jk,ik->kji'

In [84]:
# Contract grad with y directly into out_view, so the result is written onto the diagonal of `out`.
np.einsum("%s,%s->%s" % (grad_lbls, ylbls, var_lbl), grad, y, out=out_view)

array([[[ 0.,  0.],
        [ 0., 16.],
        [ 0., 32.]],

       [[ 1.,  5.],
        [ 5., 25.],
        [ 9., 45.]],

       [[ 4., 12.],
        [12., 36.],
        [20., 60.]],

       [[ 9., 21.],
        [21., 49.],
        [33., 77.]]])

In [85]:
# Only entries reachable through out_view were written, so `out` is non-zero only where
# the indices of each repeated label coincide.
out

array([[[[[ 0.,  0.],
          [ 0.,  0.],
          [ 0.,  0.]],

         [[ 1.,  0.],
          [ 0.,  0.],
          [ 0.,  0.]],

         [[ 4.,  0.],
          [ 0.,  0.],
          [ 0.,  0.]],

         [[ 9.,  0.],
          [ 0.,  0.],
          [ 0.,  0.]]],


        [[[ 0.,  0.],
          [ 0.,  0.],
          [ 0.,  0.]],

         [[ 0.,  0.],
          [ 5.,  0.],
          [ 0.,  0.]],

         [[ 0.,  0.],
          [12.,  0.],
          [ 0.,  0.]],

         [[ 0.,  0.],
          [21.,  0.],
          [ 0.,  0.]]],


        [[[ 0.,  0.],
          [ 0.,  0.],
          [ 0.,  0.]],

         [[ 0.,  0.],
          [ 0.,  0.],
          [ 9.,  0.]],

         [[ 0.,  0.],
          [ 0.,  0.],
          [20.,  0.]],

         [[ 0.,  0.],
          [ 0.,  0.],
          [33.,  0.]]]],



       [[[[ 0.,  0.],
          [ 0.,  0.],
          [ 0.,  0.]],

         [[ 0.,  5.],
          [ 0.,  0.],
          [ 0.,  0.]],

         [[ 0., 12.],
          [ 0.,  0

In [28]:
# NOTE(review): this cell used `out_lbls`, which no visible cell defines (hidden kernel
# state). "ijikik" is the unique_from_end docstring example and reproduces the recorded
# output 'jik'; replace it if a different label string was intended.
out_lbls = "ijikik"
unique_from_end(out_lbls)

'jik'

In [None]:
def f(in_lbls, out_lbls, *vars):
    """Evaluate an einsum whose output subscripts may repeat a label.

    For example, ``f("jk,ik", "ijkji", grad, y)`` returns an array of shape
    ``(i, j, k, j, i)``. It holds ``einsum("jk,ik->ijk", grad, y)`` on the entries
    where both ``i`` indices agree and both ``j`` indices agree, and zeros elsewhere.
    Ellipsis (``...``) subscripts are not supported.

    Parameters
    ----------
    in_lbls: str
        Comma-separated operand subscripts, one group per array in ``vars``.
    out_lbls: str
        Output subscripts; a label may appear more than once.
    *vars: array_like
        The einsum operands.

    Returns
    -------
    numpy.ndarray
        Zero-filled array with one axis per character of ``out_lbls``, with the
        contraction written onto its generalized diagonal.

    Raises
    ------
    ValueError
        If the number of subscript groups differs from the number of operands, or
        an operand's rank differs from the length of its subscripts.
    KeyError
        If ``out_lbls`` uses a label that no operand provides.
    """
    operands = [np.asarray(v) for v in vars]
    operand_lbls = in_lbls.split(",")
    if len(operand_lbls) != len(operands):
        raise ValueError("got %d subscript groups for %d operands"
                         % (len(operand_lbls), len(operands)))

    # Largest extent seen for each label (size-1 axes may broadcast against larger ones).
    sizes = {}
    for lbls, operand in zip(operand_lbls, operands):
        if len(lbls) != operand.ndim:
            raise ValueError("subscripts %r do not match operand of shape %r"
                             % (lbls, operand.shape))
        for lbl, size in zip(lbls, operand.shape):
            if lbl not in sizes or sizes[lbl] < size:
                sizes[lbl] = size

    # Distinct output labels in first-occurrence order: one axis of the diagonal view each.
    distinct_lbls = "".join(dict.fromkeys(out_lbls))

    out = np.zeros(tuple(sizes[lbl] for lbl in out_lbls), dtype=np.result_type(*operands))
    # One step along a view axis must advance every axis of `out` that shares its label.
    view_strides = tuple(
        sum(stride for stride, axis_lbl in zip(out.strides, out_lbls) if axis_lbl == lbl)
        for lbl in distinct_lbls
    )
    out_view = as_strided(out, shape=tuple(sizes[lbl] for lbl in distinct_lbls),
                          strides=view_strides)
    np.einsum(in_lbls + "->" + distinct_lbls, *operands, out=out_view)
    return out

In [25]:
# Minimal demo: a 1-D view whose single stride is the sum of all four strides of `out`
# visits out[0, 0, 0, 0], out[1, 1, 1, 1], out[2, 2, 2, 2].
i = 3
x = np.arange(i)
out = np.zeros((i,) * 4)
out_view = as_strided(out, shape=(i,), strides=(sum(out.strides),))
out_view[...] = x

In [26]:
# Only out[n, n, n, n] was written (with x[n]); every other entry is still zero.
out

array([[[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 1., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]],


       [[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 2.]]]])

In [27]:
# Read back with repeated labels: entry [a, b] is out[a, b, a, b], which can be
# non-zero only when a == b.
np.einsum("ijij->ij", out)

array([[0., 0., 0.],
       [0., 1., 0.],
       [0., 0., 2.]])