In [None]:
import torch

def multiply_elements_of_tensor(data: torch.Tensor, dim: int = 0, keepdim: bool = True):
    """Multiply the elements of a 1-D or 2-D tensor along one axis.

    The meaning of ``dim`` is this function's own convention and is the
    opposite of ``torch.prod``:

    * ``dim == 0`` multiplies across the columns, giving one product per row.
    * any other value multiplies across the rows, giving one product per
      column.

    Args:
        data: 1-D or 2-D tensor. A 1-D tensor is treated as a single row.
        dim: Direction of the multiplication, as described above.
        keepdim: For 2-D input, keep the reduced axis with size 1, so the
            result has shape ``(rows, 1)`` for ``dim == 0`` and
            ``(1, cols)`` otherwise. When False the result is 1-D.

    Returns:
        A float32 tensor. For 2-D input the shape is as described under
        ``keepdim``. For 1-D input a 0-dim tensor is returned regardless of
        ``keepdim``: the product of all elements when ``dim == 0``; for any
        other ``dim`` the reduction runs over the single row, so the first
        element is returned (kept for backward compatibility).

    Raises:
        ValueError: If ``data`` is not 1-D or 2-D.
    """
    if data.dim() not in (1, 2):
        raise ValueError(
            f"expected a 1-D or 2-D tensor, got {data.dim()} dimension(s)"
        )

    was_vector = data.dim() == 1
    if was_vector:
        # Treat a vector as a matrix with a single row.
        data = data.unsqueeze(dim=0)

    # dim == 0 means "across columns", which is torch axis 1, and vice versa.
    reduce_axis = 1 if dim == 0 else 0

    # Results are always float32, so convert before multiplying.
    result = torch.prod(data.to(torch.float), dim=reduce_axis, keepdim=keepdim)

    if was_vector:
        # flatten() makes the scalar lookup independent of keepdim; indexing
        # with [0, 0] failed when keepdim was False.
        return result.flatten()[0]
    return result
    

def get_determinant(matrix: torch.Tensor) -> float:
    """Return the determinant of a 2x2 matrix as a Python float.

    Uses the closed form ``a*d - b*c``: the product of the major diagonal
    minus the product of the minor diagonal. That formula is only valid for
    2x2 matrices, so any other shape is rejected instead of silently
    producing a wrong number.

    Args:
        matrix: 2x2 matrix. It is converted to a float32 tensor before the
            products are taken.

    Returns:
        The determinant as a Python float.

    Raises:
        ValueError: If ``matrix`` does not have shape ``(2, 2)``.
    """
    m = torch.as_tensor(matrix, dtype=torch.float)
    if tuple(m.shape) != (2, 2):
        raise ValueError(
            f"get_determinant expects a 2x2 matrix, got shape {tuple(m.shape)}"
        )

    # Each diagonal product is taken in float32 and turned into a Python
    # float before subtracting, so two equal products give exactly 0.0.
    prod_major_diagonal = (m[0, 0] * m[1, 1]).item()
    prod_minor_diagonal = (m[0, 1] * m[1, 0]).item()

    return prod_major_diagonal - prod_minor_diagonal


def inverse_2x2(matrix) -> torch.Tensor | None:
    """
    Compute inverse of a 2Ã2 matrix using PyTorch.
    Input can be Python list, NumPy array, or torch Tensor.
    Returns a 2Ã2 tensor or None if the matrix is singular.
    """
    m = torch.as_tensor(matrix, dtype=torch.float)
    # Your implementation here
    if get_determinant(matrix=m) == 0:
		return None
	else:
		return torch.linalg.inv(m)
