# Linear Layer

## Summary

**The Linear Transformation**

In the context of modern neural networks, a linear layer (also known as a fully connected or dense layer) is typically represented as:

$$y = xW^T + b$$

**Component Breakdown**

* $x$: The input tensor. In a Transformer, this is usually a tensor of shape `[batch_size, sequence_length, in_features]`.
* $W^T$: The transpose of the weight matrix $W$, which is stored with shape `[out_features, in_features]`.
* $b$: The bias vector (optional) of shape `[out_features]`, which is added to the result of the matrix multiplication.
* $y$: The output tensor, of shape `[batch_size, sequence_length, out_features]`.
The implementation below omits the bias term, so it computes only

$$y = xW^T$$
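To make the shapes concrete, the sketch below evaluates the full formula $y = xW^T + b$ with plain tensors. The sizes are arbitrary choices for illustration, and `W` is stored as `[out_features, in_features]` like the weight in the implementation. The last line checks that each output position equals the column-vector form $Wx + b$ applied to a single input vector.

```python
import torch

torch.manual_seed(0)
batch_size, sequence_length, in_features, out_features = 2, 5, 8, 3

x = torch.randn(batch_size, sequence_length, in_features)  # input x
W = torch.randn(out_features, in_features)                 # weight W, stored (out, in)
b = torch.randn(out_features)                              # optional bias b

# (2, 5, 8) @ (8, 3) -> (2, 5, 3); b broadcasts across the leading dimensions
y = x @ W.T + b
print(y.shape)  # torch.Size([2, 5, 3])

# Each position is the column-vector form W x + b applied to one input vector
print(torch.allclose(y[0, 0], W @ x[0, 0] + b, atol=1e-5))  # True
```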

### Implementation

```python
import torch
import math
from torch import Tensor
from jaxtyping import Float
import torch.nn as nn
from torch.nn.parameter import Parameter

class Linear(nn.Module):
    __constants__ = ["in_features", "out_features"]
    
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(
        self,
        in_features: int,                   # input dimensions
        out_features: int,                  # output dimensions
        device: torch.device | None = None, # CPU or GPU
        dtype: torch.dtype | None = None    # float32, 64, etc
    ) -> None:
        
        factory_kwargs = {"device": device, "dtype": dtype}
        
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        self.reset_parameters()
    
    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.weight)
    
    def forward(self, x: torch.Tensor) -> Tensor:
        return x @ self.weight.T
```

## Step-by-Step

#### Defining the `Linear` class

```python
import torch
import math
from torch import Tensor
from jaxtyping import Float
import torch.nn as nn
from torch.nn.parameter import Parameter

# Subclass nn.Module so the layer's parameters are tracked automatically.
# (Excerpt: reset_parameters and forward are shown in the cells below;
# the complete class is defined in the Implementation cell above.)
class Linear(nn.Module):

    # fixed configuration values of the layer
    __constants__ = ["in_features", "out_features"]
    
    in_features: int
    out_features: int
    weight: Tensor  # stores the weight matrix, shape (out_features, in_features)

    def __init__(
        self,
        in_features: int,                   # input dimensions
        out_features: int,                  # output dimensions
        device: torch.device | None = None, # CPU or GPU
        dtype: torch.dtype | None = None    # float32, float64, etc.
    ) -> None:
        
        # Collected here and later unpacked into torch.empty() with **,
        # i.e. torch.empty(..., device=device, dtype=dtype)
        factory_kwargs = {"device": device, "dtype": dtype}
        
        # Call the parent constructor first: nn.Module must be initialized
        # before any Parameter is assigned as an attribute
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Create the weight matrix. Wrapping it in Parameter registers it with
        # the module, so it appears in model.parameters() and is updated by the
        # optimizer; a plain tensor attribute would not be trained.
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        
        self.reset_parameters()
```
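The effect of wrapping the weight in `Parameter` can be checked directly. The sketch below uses two small hypothetical modules (`WithParameter` and `WithPlainTensor`, made up for this illustration) that differ only in whether the tensor is wrapped, and then shows the `**factory_kwargs` unpacking on its own.

```python
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

class WithParameter(nn.Module):  # hypothetical demo module
    def __init__(self) -> None:
        super().__init__()
        self.weight = Parameter(torch.empty(4, 8))  # registered as a parameter

class WithPlainTensor(nn.Module):  # hypothetical demo module
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.empty(4, 8)  # ordinary attribute, not registered

print([name for name, _ in WithParameter().named_parameters()])    # ['weight']
print([name for name, _ in WithPlainTensor().named_parameters()])  # []
print(WithParameter().weight.requires_grad)                        # True

# ** unpacks the dict into keyword arguments:
# same as torch.empty((4, 8), device=None, dtype=torch.float64)
factory_kwargs = {"device": None, "dtype": torch.float64}
w = torch.empty((4, 8), **factory_kwargs)
print(w.dtype)  # torch.float64
```

Only the registered tensor shows up in `named_parameters()`, which is the list an optimizer receives via `model.parameters()`.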


### Initialization

**`trunc_normal_`** initialization

![Truncated Normal Distribution](/home/yohei.ohata.ee/Projects/stanford-cs336/assignment1-basics/cs336_basics/images/TruncatedNormal.png)

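`nn.init.trunc_normal_` fills the given tensor in place with values drawn from a truncated normal distribution.

The values are effectively drawn from the normal distribution $\mathcal{N}(\textrm{mean}, \textrm{std}^2)$, with values outside $[a, b]$ redrawn until they fall within the bounds. Called with no extra arguments, as in `reset_parameters` below, it uses `mean=0.0`, `std=1.0`, `a=-2.0`, `b=2.0`, so every weight lies in $[-2, 2]$.

The method used for generating the random values works best when $a \le \textrm{mean} \le b$.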
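The sketch below checks the default behavior empirically: all samples stay in $[-2, 2]$, and because the tails are cut off, the sample standard deviation comes out around 0.88 rather than 1. The second half shows a variance-scaled variant with $\textrm{std} = \sqrt{2 / (\textrm{in} + \textrm{out})}$ and bounds at $\pm 3\,\textrm{std}$; that choice is an assumption for illustration, not what the `reset_parameters` above does. Note that `a` and `b` are absolute values, not multiples of `std`.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Default call, as in reset_parameters: mean=0, std=1, bounds [a, b] = [-2, 2]
w_default = torch.empty(512, 256)
nn.init.trunc_normal_(w_default)
print(w_default.min().item(), w_default.max().item())  # both inside [-2, 2]
print(w_default.std().item())  # about 0.88: truncation shrinks the spread below 1

# Hypothetical variance-scaled variant (an assumption, not the class above):
# std shrinks as the layer gets wider, and the bounds are set to +/- 3 std
in_features, out_features = 256, 512
std = math.sqrt(2.0 / (in_features + out_features))
w_scaled = torch.empty(out_features, in_features)
nn.init.trunc_normal_(w_scaled, mean=0.0, std=std, a=-3 * std, b=3 * std)
print(std, w_scaled.abs().max().item())  # max magnitude stays within 3 * std
```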

```python
def reset_parameters(self) -> None:
    # Uses the defaults: mean=0.0, std=1.0, bounds a=-2.0, b=2.0
    nn.init.trunc_normal_(self.weight)
```

### Matrix Multiplication

```python
def forward(self, x: torch.Tensor) -> Tensor:
    # (..., in_features) @ (in_features, out_features) -> (..., out_features)
    return x @ self.weight.T
```
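As a cross-check, the sketch below compares the `forward` expression against PyTorch's built-in `torch.nn.functional.linear` (which computes $xW^T + b$, here without a bias) and against an explicit `einsum` contraction. The 3-D input mimics a Transformer activation of shape `(batch, seq, in_features)`; the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, d_in, d_out = 2, 3, 8, 4
x = torch.randn(batch, seq, d_in)
W = torch.randn(d_out, d_in)  # stored as (out_features, in_features)

y_matmul = x @ W.T                            # what forward() computes
y_functional = F.linear(x, W)                 # built-in x W^T, no bias
y_einsum = torch.einsum("bsi,oi->bso", x, W)  # same contraction, spelled out

print(y_matmul.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(y_matmul, y_functional, atol=1e-5))  # True
print(torch.allclose(y_matmul, y_einsum, atol=1e-5))      # True
```

The `@` operator broadcasts over all leading dimensions, so the same one-line `forward` handles both 2-D `(batch, in_features)` and 3-D `(batch, seq, in_features)` inputs.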

#### Appendix

**PyTorch Weight Storage & Matrix Multiplication**

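In PyTorch, `nn.Linear` stores its weight with shape `(out_features, in_features)`, and the class above follows the same convention. This matches the mathematical notation $y = Wx$ for a single column vector $x$, where $W$ has `out_features` rows and `in_features` columns. Because a batch stores each input as a *row* of $x$, the same matrix enters transposed, which is why the forward pass computes $y = xW^T$. In row-major memory, this layout also keeps each row of `weight` (the `in_features` weights feeding one output unit) contiguous.

> **Note:** If you stored the weight as `(in_features, out_features)` instead, `x @ self.weight` would compute the same result. `self.weight.T` is a view with swapped strides rather than a copy, and GEMM routines accept transposed operands directly, so any speed difference between the two layouts depends on the backend and tensor shapes rather than being guaranteed.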


**Memory Layout and GEMM**

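* **Storage Shape:** `(out_features, in_features)`, contiguous in row-major order.
* **Execution:** `weight.T` is a view of shape `(in_features, out_features)` with swapped strides, so it is column-major with respect to its new shape. When you run `x @ weight.T`, PyTorch passes both operands to an optimized GEMM (General Matrix Multiply) routine, which can read the transposed operand in place instead of copying it.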
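The claim that `.T` is a view can be verified by inspecting strides, i.e. how many elements to skip in memory to move one step along each dimension. The sketch below uses the same `(4, 8)` weight shape as the example further down.

```python
import torch

weight = torch.randn(4, 8)  # (out_features, in_features), row-major
wt = weight.T               # (in_features, out_features) view

print(weight.shape, weight.stride(), weight.is_contiguous())  # torch.Size([4, 8]) (8, 1) True
print(wt.shape, wt.stride(), wt.is_contiguous())              # torch.Size([8, 4]) (1, 8) False
print(wt.data_ptr() == weight.data_ptr())                     # True: same memory, no copy
```

Stride `(8, 1)` means moving along a row touches adjacent memory (row-major); the transposed view's stride `(1, 8)` means moving down a column does, which is column-major for the `(8, 4)` shape.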

---

**Memory Layout: Row-Major vs. Column-Major**

The way a language stores a 2D array in linear memory (RAM) dictates how fast certain operations will be.

![Row-Major vs. Column-Major](/home/yohei.ohata.ee/Projects/stanford-cs336/assignment1-basics/cs336_basics/images/row-major-column-major.webp)


**Comparison Table**

| Feature | Row-Major Order | Column-Major Order |
| --- | --- | --- |
| **Storage Logic** | Elements are stored row-by-row. | Elements are stored column-by-column. |
| **Adjacent Elements** | $A_{i,j}$ and $A_{i,j+1}$ are neighbors. | $A_{i,j}$ and $A_{i+1,j}$ are neighbors. |
| **Performance** | Faster row-wise access. | Faster column-wise access. |
| **Primary Use** | General purpose programming. | Scientific & Mathematical computing. |
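The two storage orders differ only in how an index pair $(i, j)$ maps to a flat memory offset: $i \cdot n_{\text{cols}} + j$ for row-major and $i + j \cdot n_{\text{rows}}$ for column-major. The plain-Python sketch below lays out a small $2 \times 3$ matrix both ways; the helper names are hypothetical, chosen for this illustration.

```python
# A 2 x 3 matrix written as nested lists: A[i][j]
A = [[0, 1, 2],
     [3, 4, 5]]
n_rows, n_cols = 2, 3

def row_major_offset(i: int, j: int) -> int:
    # consecutive j (same row) land next to each other
    return i * n_cols + j

def col_major_offset(i: int, j: int) -> int:
    # consecutive i (same column) land next to each other
    return i + j * n_rows

row_major = [0] * (n_rows * n_cols)
col_major = [0] * (n_rows * n_cols)
for i in range(n_rows):
    for j in range(n_cols):
        row_major[row_major_offset(i, j)] = A[i][j]
        col_major[col_major_offset(i, j)] = A[i][j]

print(row_major)  # [0, 1, 2, 3, 4, 5]
print(col_major)  # [0, 3, 1, 4, 2, 5]
```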

**Language Ecosystems**

| Type | Key Languages |
| --- | --- |
| **Row-Major** (C-Style) | C, C++, Python (NumPy default), Pascal, SAS, HLSL |
| **Column-Major** (Fortran-Style) | Fortran, MATLAB, R, Julia, Scilab, OpenGL (GLSL) |

### Example

```python
# 1. Setup dimensions
batch_size = 4
in_features = 8
out_features = 4

# 2. Instantiate your custom layer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Linear(in_features, out_features, device=device)

# 3. Create a dummy input tensor
# Shape: (Batch Size, In Features)
input_tensor: Float[Tensor, "batch in_features"] = torch.randn(batch_size, in_features).to(device)

# 4. Run the forward pass
output = model(input_tensor)

# 5. Check the results
print(f"Input Shape:  {input_tensor.shape}")
print("Input Tensor:")
print(input_tensor)
print(f"\nWeight Shape: {model.weight.shape}")
print("Weight Tensor:")
print(model.weight)
print(f"\nOutput Shape: {output.shape}")
print("Output Tensor:")
print(output)
```

Input Shape:  torch.Size([4, 8])
Input Tensor:
tensor([[-1.6999, -0.4738, -0.5925,  0.4202,  0.1482,  0.4658, -1.4073, -1.0404],
        [ 0.6700,  0.9281, -1.8762,  0.3649,  0.5327, -1.8633,  0.0222,  0.6440],
        [ 1.4492,  0.6457,  0.0057, -0.6342,  0.8637,  0.4045,  0.0803, -0.0030],
        [ 2.4281,  0.1656,  0.7150, -1.2204, -2.6843,  0.8813, -1.1786,  0.4195]])

Weight Shape: torch.Size([4, 8])
Weight Tensor:
Parameter containing:
tensor([[ 1.7370,  0.4927, -0.4564, -0.9130,  0.8827,  0.0034,  0.6451,  0.0218],
        [ 1.3956, -0.2552,  0.8936, -0.4262, -0.3371, -0.3544,  0.1371,  0.9575],
        [ 0.7492,  0.0960, -1.3545, -0.7660,  0.8891, -1.1542,  0.8908,  0.2677],
        [ 0.1911,  0.5428, -0.2852,  0.9995, -0.6744,  0.6710,  0.3653, -0.5579]],
       requires_grad=True)

Output Shape: torch.Size([4, 4])
Output Tensor:
tensor([[-4.0976, -4.3640, -2.7765,  0.2859],
        [ 2.6365, -0.0334,  5.6693, -0.4291],
        [ 4.2273,  1.7066,  1.9976, -0.2881],
        [ 