
In [None]:
# Use `torch.cuda.is_available()` to check for a CUDA device before moving tensors or models to it.
# This avoids runtime errors on CPU-only machines and lets the same code use a GPU when one is present.

# Wrap inference and evaluation code in the `torch.no_grad()` context manager.
# Autograd then stops recording operations and storing intermediate activations, which saves memory and time.

# Use `torch.jit.trace()` to create a TorchScript version of your model from an example input.
# The traced module can be saved and run without the original Python code. Note that tracing records only
# the operations executed for that input, so data-dependent control flow is not captured.

# Use `torch.utils.data.DataLoader` to load your data in batches.
# It handles batching and shuffling, and with `num_workers > 0` it prepares batches in background worker
# processes so the model spends less time waiting for data.

# Use a scheduler from `torch.optim.lr_scheduler` (e.g. `StepLR`) to adjust the learning rate during training.
# Lowering the learning rate as training progresses can help the model converge.

# Use `torch.nn.utils.clip_grad_norm_` after `loss.backward()` and before `optimizer.step()`.
# It rescales the gradients so their combined norm does not exceed `max_norm`, which keeps exploding
# gradients from causing training to diverge.

# Use `torch.utils.tensorboard.SummaryWriter` to log metrics such as the training loss.
# Viewing them in TensorBoard helps you track training progress and spot problems early.
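
Several of the tips above fit together in one short training loop. The following is a minimal sketch, assuming a toy linear-regression problem on synthetic data; the model, learning rate, `StepLR` settings (`step_size=5`, `gamma=0.5`), and `max_norm=1.0` are illustrative choices, not values taken from this notebook.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Fall back to the CPU when no CUDA device is present
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic data (illustrative only): y = 3x + small noise
x = torch.randn(512, 1)
y = 3 * x + 0.1 * torch.randn(512, 1)
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = nn.Linear(1, 1).to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Halve the learning rate every 5 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

for epoch in range(20):
    model.train()
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        # Rescale gradients so their combined L2 norm is at most 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
    scheduler.step()  # advance the schedule once per epoch

# Evaluation: no_grad() keeps autograd from recording the forward pass
model.eval()
with torch.no_grad():
    eval_loss = loss_fn(model(x.to(device)), y.to(device)).item()

print(f"final lr: {scheduler.get_last_lr()[0]}, eval loss: {eval_loss:.4f}")
```

Calling `scheduler.step()` once per epoch, after the inner loop's `optimizer.step()` calls, is the order PyTorch expects; after 20 epochs the learning rate has been halved four times.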

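A minimal sketch of `torch.jit.trace()`, assuming a small illustrative `nn.Sequential` model (not from this notebook). It checks that the traced module matches the eager model and round-trips it through an in-memory buffer with `torch.jit.save()` and `torch.jit.load()`.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)

# A small illustrative model (hypothetical, not from the notebook)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

# trace() runs the model once on example_input and records the ops it executes
example_input = torch.randn(1, 4)
traced = torch.jit.trace(model, example_input)

# The traced module gives the same outputs as the eager model
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), traced(x))
print(same)

# Save and reload the traced module; loading does not need the Python class definition
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
```

For models whose forward pass branches on tensor values, `torch.jit.script()` compiles the Python control flow instead of recording a single execution.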

### PyTorch Tip - 3


**Using `torch.where()` for Conditional Element-wise Operations**

PyTorch's `torch.where()` function selects each output element from one of two tensors based on a condition. It takes three arguments: a boolean condition tensor, the tensor to take values from where the condition is `True`, and the tensor to take values from where it is `False`. The three arguments are broadcast to a common shape, and the selection is vectorized, so it replaces a Python loop or per-element if/else. This makes it useful for implementing conditional logic inside neural network models.

Here's a quick example. Elements of `tensor_true` are selected where `condition` is `True`, and elements of `tensor_false` are selected where it is `False`:

In [1]:
import torch

# Define tensors
condition = torch.tensor([[True, False], [False, True]])
tensor_true = torch.tensor([[1, 2], [3, 4]])
tensor_false = torch.tensor([[5, 6], [7, 8]])

# Perform conditional element-wise operation
result = torch.where(condition, tensor_true, tensor_false)
print(result)


tensor([[1, 6],
        [7, 4]])
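
Because `torch.where()` broadcasts its arguments, the value arguments don't need to match the condition's shape. A minimal sketch with illustrative tensors: a 0-dim tensor serves as the fallback value for a ReLU-like clamp, and `torch.where()` guards a division by zero.

```python
import torch

x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])

# A 0-dim tensor broadcasts against x: keep positive values, replace the rest with 0 (like ReLU)
relu_like = torch.where(x > 0, x, torch.tensor(0.0))
print(relu_like)

# Guard a division: use 0 wherever the denominator is zero
num = torch.tensor([1.0, 2.0, 3.0])
den = torch.tensor([2.0, 0.0, 4.0])
safe_div = torch.where(den != 0, num / den, torch.zeros_like(num))
print(safe_div)
```

Note that `num / den` is still computed for every element (producing `inf` where `den` is 0); `torch.where()` only chooses which results to keep. If gradients flow through such an expression, the unselected `inf` entries can still yield NaN gradients, so it is safer to make the denominator nonzero before dividing.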
