In [1]:
import torch

In [2]:

# Create example tensors of different types.
# Seed the global RNG so the random complex tensor below is reproducible
# when the notebook is re-run from a fresh kernel.
torch.manual_seed(0)
float_tensor = torch.tensor([1.0, 2.0, 3.0])       # float32 (default float dtype)
int_tensor = torch.tensor([1, 2, 3])               # int64 (default integer dtype)
bool_tensor = torch.tensor([True, False, True])    # bool
complex_tensor = torch.complex(real=torch.rand(3), imag=torch.rand(3))  # complex64 built from two float32 parts


# ===== Basic Type Checking =====

In [3]:
# Get dtype of tensor
dtype = float_tensor.dtype  # returns torch.float32

# Check specific types by comparing against a concrete dtype
is_float = float_tensor.dtype == torch.float32
is_int = int_tensor.dtype == torch.int64      # int64 is the default dtype for integer literals
is_bool = bool_tensor.dtype == torch.bool

# Built-in type checking methods
is_floating = float_tensor.is_floating_point()  # True for real floating dtypes (float16, bfloat16, float32, float64, ...); False for complex
is_complex = complex_tensor.is_complex()        # True for complex dtypes (complex32, complex64, complex128)
is_signed = int_tensor.is_signed()              # True for ANY signed dtype: signed ints and also floating/complex; False for bool and unsigned (e.g. uint8)
is_inference = float_tensor.is_inference()      # True only for tensors created under torch.inference_mode(); unrelated to dtype

# ===== Detailed Type Checking =====

In [4]:
def check_tensor_type(tensor):
    """Comprehensive type checking for a PyTorch tensor.

    Collects dtype, numeric-kind, memory, device and autograd properties of
    ``tensor`` into one dict. The keys and their order are fixed.

    PyTorch raises for some properties on certain tensor kinds. Those entries
    get a fallback value so the rest of the report is still produced:

    * ``is_contiguous`` is ``False`` for non-strided layouts (e.g. sparse
      COO/CSR), where ``Tensor.is_contiguous()`` raises.
    * ``is_signed`` for quantized tensors is derived from the quantized dtype
      (``qint8``/``qint32`` signed, ``quint*`` unsigned), because
      ``Tensor.is_signed()`` raises for quantized dtypes.

    Args:
        tensor: the ``torch.Tensor`` to inspect.

    Returns:
        dict mapping property name -> value. Most values are ``bool``;
        ``'dtype'`` is a ``torch.dtype`` and ``'type'`` is the Python class
        of ``tensor``.
    """
    # Contiguity is only defined for strided (dense) layouts.
    is_strided = tensor.layout == torch.strided

    if tensor.is_quantized:
        # Tensor.is_signed() refuses quantized dtypes; qint8/qint32 hold
        # signed integers while the quint* dtypes are unsigned.
        signed = tensor.dtype in (torch.qint8, torch.qint32)
    else:
        signed = tensor.is_signed()

    type_info = {
        # Basic type information
        'dtype': tensor.dtype,
        'type': type(tensor),

        # Numeric type checks
        'is_floating_point': tensor.is_floating_point(),
        'is_complex': tensor.is_complex(),
        'is_signed': signed,

        # Memory checks
        'is_contiguous': tensor.is_contiguous() if is_strided else False,
        # Only queried when CUDA is available; reported as False otherwise.
        'is_pinned': tensor.is_pinned() if torch.cuda.is_available() else False,

        # Device checks
        'is_cuda': tensor.is_cuda,
        'is_cpu': tensor.device.type == 'cpu',

        # Autograd / layout / representation properties
        'requires_grad': tensor.requires_grad,
        'is_leaf': tensor.is_leaf,
        'is_sparse': tensor.is_sparse,
        'is_quantized': tensor.is_quantized,
        # hasattr guard: presumably for PyTorch builds without Tensor.is_nested
        'is_nested': tensor.is_nested if hasattr(tensor, 'is_nested') else False,
    }
    return type_info

# ===== Size and Shape Checking =====

In [5]:
def check_tensor_shape(tensor):
    """Summarise the shape and dimensionality of ``tensor``.

    Returns a dict with:
      * the shape, stored under both ``'shape'`` and ``'size'``,
      * the number of dimensions and the total element count,
      * flags classifying the tensor as a 0-d scalar, 1-d vector or 2-d matrix.
    """
    rank = tensor.ndim
    return {
        'shape': tensor.shape,
        'ndim': rank,
        'size': tensor.size(),
        'numel': tensor.numel(),  # product of all dimension sizes
        'is_scalar': rank == 0,
        'is_vector': rank == 1,
        'is_matrix': rank == 2,
    }

# ===== Type Compatibility Checking =====

In [6]:
def check_type_compatibility(tensor1, tensor2):
    """Check whether two tensors are compatible for element-wise operations.

    Args:
        tensor1, tensor2: the ``torch.Tensor`` operands to compare.

    Returns:
        dict with boolean entries:
          * ``'same_dtype'`` / ``'same_device'``: exact dtype / device equality.
          * ``'both_floating'`` / ``'both_complex'``: both operands are of
            that numeric kind.
          * ``'can_add'``: whether ``tensor1 + tensor2`` is expected to
            succeed (same device, broadcastable shapes, not quantized).
    """
    # Out-of-place addition type-promotes its operands (see torch.result_type),
    # so differing dtypes never block it. torch.can_cast answers a different,
    # directional question: can_cast(float32, int32) is False, yet
    # float32 + int32 is valid. What can block the op is incompatible shapes.
    try:
        torch.broadcast_shapes(tensor1.shape, tensor2.shape)
        broadcastable = True
    except (RuntimeError, ValueError):
        broadcastable = False

    same_device = tensor1.device == tensor2.device
    # Quantized tensors need dedicated quantized ops rather than plain `+`.
    any_quantized = tensor1.is_quantized or tensor2.is_quantized

    compatibility = {
        'same_dtype': tensor1.dtype == tensor2.dtype,
        'same_device': same_device,
        'both_floating': tensor1.is_floating_point() and tensor2.is_floating_point(),
        'both_complex': tensor1.is_complex() and tensor2.is_complex(),
        # NOTE: conservative - requires identical devices.
        'can_add': same_device and broadcastable and not any_quantized,
    }
    return compatibility

In [7]:

# Three small tensors that differ only in dtype
float32_tensor = torch.tensor([1.0, 2.0], dtype=torch.float32)
float64_tensor = torch.tensor([1.0, 2.0], dtype=torch.float64)
int32_tensor = torch.tensor([1, 2], dtype=torch.int32)

# Each report prints as "<heading>\n<dict>". The leading "\n" on the later
# headings leaves a blank line between sections.
print("Float32 Tensor Type Info:", check_tensor_type(float32_tensor), sep="\n")
print("\nShape Info:", check_tensor_shape(float32_tensor), sep="\n")
print("\nType Compatibility:", check_type_compatibility(float32_tensor, float64_tensor), sep="\n")


Float32 Tensor Type Info:
{'dtype': torch.float32, 'type': <class 'torch.Tensor'>, 'is_floating_point': True, 'is_complex': False, 'is_signed': True, 'is_contiguous': True, 'is_pinned': False, 'is_cuda': False, 'is_cpu': True, 'requires_grad': False, 'is_leaf': True, 'is_sparse': False, 'is_quantized': False, 'is_nested': False}

Shape Info:
{'shape': torch.Size([2]), 'ndim': 1, 'size': torch.Size([2]), 'numel': 2, 'is_scalar': False, 'is_vector': True, 'is_matrix': False}

Type Compatibility:
{'same_dtype': False, 'same_device': True, 'both_floating': True, 'both_complex': False, 'can_add': True}



# ===== Type Safety Checks =====

In [9]:
def safe_operation(tensor1, tensor2, operation='add'):
    """Safely perform an element-wise arithmetic operation after type checking.

    Args:
        tensor1, tensor2: operands; must live on the same device.
        operation: ``'add'`` (``torch.add``) or ``'multiply'`` (``torch.mul``).

    Returns:
        The result tensor. Its dtype follows PyTorch type promotion.

    Raises:
        TypeError: if neither operand is floating point or complex, i.e. the
            operation would be purely integer/boolean arithmetic.
        ValueError: if the tensors are on different devices, or if
            ``operation`` is not supported.
    """
    def _is_float_like(t):
        # Complex dtypes hold floating-point values, but is_floating_point()
        # reports False for them, so both checks are needed.
        return t.is_floating_point() or t.is_complex()

    if not (_is_float_like(tensor1) or _is_float_like(tensor2)):
        raise TypeError("At least one tensor should be floating point or complex")

    if tensor1.device != tensor2.device:
        raise ValueError("Tensors must be on the same device")

    if operation == 'add':
        return torch.add(tensor1, tensor2)
    elif operation == 'multiply':
        return torch.mul(tensor1, tensor2)
    else:
        raise ValueError(f"Unsupported operation: {operation}")

# ===== Memory Format Checking =====
def check_memory_format(tensor):
    """Report which memory formats ``tensor`` is contiguous in, plus its strides.

    Keys:
      * ``'is_contiguous'``: contiguous in the default (row-major) format.
      * ``'is_contiguous_channels_last'``: contiguous in
        ``torch.channels_last`` (a 4-d layout).
      * ``'is_contiguous_channels_last_3d'``: contiguous in
        ``torch.channels_last_3d`` (a 5-d layout).
      * ``'stride'``: tuple of per-dimension strides.

    Non-strided tensors (e.g. sparse COO/CSR) have no strides, and
    ``is_contiguous()`` raises for them. For such layouts all three flags are
    ``False`` and ``'stride'`` is ``None``, so the function does not raise.
    """
    if tensor.layout != torch.strided:
        return {
            'is_contiguous': False,
            'is_contiguous_channels_last': False,
            'is_contiguous_channels_last_3d': False,
            'stride': None,
        }

    format_info = {
        'is_contiguous': tensor.is_contiguous(),
        'is_contiguous_channels_last': tensor.is_contiguous(memory_format=torch.channels_last),
        'is_contiguous_channels_last_3d': tensor.is_contiguous(memory_format=torch.channels_last_3d),
        'stride': tensor.stride(),
    }
    return format_info