<a href="https://colab.research.google.com/github/szhou12/gpt-from-scratch/blob/main/pytorch_funcs_review.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# torch.nn.Embedding(num_embeddings, embedding_dim)
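- `nn.Embedding(n, d)`: a lookup table with `n` rows, each a learnable vector of size `d`. Calling the module with integer indices returns the corresponding rows.
- Use `.weight` to inspect the contents of the embedding table: a `Parameter` of shape `(n, d)`, initialized from a standard normal distribution by default.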
- Official Doc: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html

In [2]:
vocab_size = 4
# Here embedding_dim equals vocab_size, so the table is a 4x4 matrix (one row per token id)
embedding_table = nn.Embedding(vocab_size, vocab_size)
print(embedding_table.weight)


Parameter containing:
tensor([[-0.2741, -0.4954,  0.6855,  1.7356],
        [ 0.2998,  0.9350,  1.4132,  0.1495],
        [ 1.5562, -0.4624, -0.8990, -0.2129],
        [ 0.7861,  0.6712,  0.8209, -0.7189]], requires_grad=True)
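The table above only shows the stored weights. As a minimal sketch of the lookup itself, the example below indexes a small table with a batch of token ids and checks that each output vector is the matching row of `.weight`. The names `table` and `idx` and the sizes (4 rows, 3 columns) are illustrative assumptions, not taken from the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # only to make the random initialization repeatable

vocab_size, embedding_dim = 4, 3  # illustrative sizes
table = nn.Embedding(vocab_size, embedding_dim)

# A batch of integer token ids with shape (2, 2); every id must be < vocab_size
idx = torch.tensor([[0, 2], [3, 3]])

# The lookup appends the embedding dimension: (2, 2) -> (2, 2, 3)
out = table(idx)
print(out.shape)

# Position [0, 1] holds token id 2, so the output there is row 2 of the table
rows_match = torch.equal(out[0, 1], table.weight[2])
print(rows_match)
```

The output shape is always the index shape with `embedding_dim` appended, which is why a `(B, T)` batch of token ids becomes a `(B, T, d)` tensor.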


# .to(device)
- Moves a model's parameters and buffers (e.g. weights and biases) to the specified computing device (e.g. a GPU).
- For `nn.Module` objects this is an in-place operation: `model` itself is moved to the device, and the call returns that same object. For plain tensors, `.to(device)` does not modify the tensor in place, so keep the returned value, e.g. `x = x.to(device)`.
- It's common practice to write `m = model.to(device)`. However, `m` is just another reference to the same object as `model`, so both names refer to the moved model. The line can be shortened to `model.to(device)` if the separate reference `m` is not needed later.

In [19]:
# In Colab: Runtime -> Change runtime type -> select 'T4 GPU' to make 'cuda' available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Example model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)  # A simple linear layer: 10 inputs -> 5 outputs

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
m = model.to(device)

# Check: Iterate through all parameters in 'model' and print their device
for name, param in model.named_parameters():
    print(f"model: {name} is on {param.device}")

# Check: Iterate through all parameters in 'm' and print their device
for name2, param2 in m.named_parameters():
    print(f"m: {name2} is on {param2.device}")

model: linear.weight is on cuda:0
model: linear.bias is on cuda:0
m: linear.weight is on cuda:0
m: linear.bias is on cuda:0
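The device printout shows that both names see the moved parameters, but it does not directly show that they are the same object. A minimal sketch of that check, plus the contrast with plain tensors, follows. The names `layer`, `returned`, `x`, and `y` are illustrative, and the code falls back to the CPU when no GPU is available.

```python
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

layer = nn.Linear(10, 5)

# Module.to() moves the module in place and returns the module itself
returned = layer.to(device)
same_object = returned is layer
print(same_object)

# Tensor.to() does not move the tensor in place, so rebind the name to the result.
# (If the tensor is already on the target device, the same tensor is returned.)
x = torch.randn(2, 10)
x = x.to(device)

# Inputs and parameters must be on the same device for the forward pass
y = layer(x)
print(y.shape, y.device)
```

Because the module is moved in place, `layer.to(device)` on its own line is enough. For tensors, the assignment `x = x.to(device)` is what makes the moved tensor available under the name `x`.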
