In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

In [None]:
def _fill_constant_(w):
    """Fill ``w`` in place with the constant 0.01 and return it."""
    return nn.init.constant_(w, 0.01)


# Registry mapping an initializer name to a callable that initializes a
# weight tensor in place. Entries are looked up by name.
init_method_dict = {
    "zero": nn.init.zeros_,
    "constant": _fill_constant_,
    "uniform": nn.init.uniform_,
    "kaiming-he": nn.init.kaiming_normal_,
    "xavier-glorot": nn.init.xavier_uniform_,
}


class Model(nn.Module):
    """Conv -> ReLU -> flatten -> linear network with a selectable weight init.

    ``model_args`` must expose ``in_dim``, ``out_dim``, ``bias`` and
    ``init_method`` as attributes.
    """

    def __init__(self, model_args):
        super().__init__()

        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        # in_dim is assumed to be the feature count after flattening the conv
        # output; out_dim is the final output size (e.g. number of classes).
        self.linear = nn.Linear(
            in_features=model_args.in_dim,
            out_features=model_args.out_dim,
            bias=model_args.bias,
        )

        # Replace the default initial values with the requested scheme.
        self.reset_parameters(model_args.init_method)

    def reset_parameters(self, init_method):
        """Re-initialize every Conv2d/Linear weight and zero its bias.

        Raises:
            ValueError: if ``init_method`` is not a key of ``init_method_dict``.
        """
        init_fn = init_method_dict.get(init_method)
        if init_fn is None:
            raise ValueError(f"Unknown init_method: {init_method}")

        # The traversal also yields this Model itself and the ReLU, so only
        # layers that actually own a weight are touched.
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            init_fn(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Apply conv and ReLU, flatten all non-batch dims, then the linear layer."""
        features = self.relu(self.conv(x))
        # Collapse (B, C, H, W) into (B, C*H*W) before the linear layer.
        return self.linear(features.flatten(start_dim=1))

### 출력 클래스
- 단순히 dict를 반환하는 것보다 유연하고 가독성이 좋은 설계
- 속성으로 접근하므로 IDE 수준에서 오타를 발견할 수 있음

In [15]:
from typing import Optional, Tuple
from dataclasses import dataclass
from collections import OrderedDict

class ModelOutput(OrderedDict):
    """Ordered mapping whose entries are also readable as attributes.

    Values passed to the constructor, or assigned with ``out[key] = value``,
    are stored both as mapping entries and as instance attributes, so
    ``out.name``, ``out["name"]`` and ``out[0]`` all work.

    Notes:
        * Looking up a name that does not exist returns ``None`` instead of
          raising ``KeyError``.
        * Plain attribute assignment (``out.name = value``) and entry removal
          (``del out[key]``, ``pop``) are not intercepted, so they update only
          one of the two stores.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Go through __setitem__ so the mapping storage is filled together
        # with the attributes. Writing only to __dict__ left the mapping empty,
        # so len(), keys(), items(), iteration and truthiness all reported an
        # empty container.
        for key, value in kwargs.items():
            self[key] = value

    def __getitem__(self, key):
        """Return the value for a field name or for an integer position."""
        if isinstance(key, int):
            # Positional access follows the order in which attributes were set.
            key = list(self.__dict__)[key]
        return getattr(self, key, None)

    def __setitem__(self, key, value):
        """Store ``value`` as both an attribute and a mapping entry."""
        setattr(self, key, value)
        super().__setitem__(key, value)


@dataclass(init=False)
class EncoderOutput(ModelOutput):
    """Output container for an encoder.

    Both attributes default to ``None``. They are annotated so that
    ``@dataclass`` registers them as fields: without annotations no fields
    exist, the generated ``__repr__`` always prints ``EncoderOutput()`` and
    the generated ``__eq__`` treats any two instances as equal. With the
    fields registered, repr shows the stored values and equality compares
    them field by field.
    """

    # Output of the last encoder layer.
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer outputs; Encoder.forward passes a list, or None when
    # intermediates are not requested.
    hidden_states: Optional[list] = None

In [19]:
# Build an output object from keyword arguments and read both values back as
# attributes; the trailing expression is the value the cell displays.
t = EncoderOutput(last_hidden_state=1, hidden_states=2)
t.last_hidden_state, t.hidden_states

(1, 2)

In [20]:
class Encoder(nn.Module):
    """Toy encoder: 4x4 patch embedding followed by a stack of Linear+ReLU blocks.

    Args:
        embed_dim: channel width of the patch embedding and of every block.
        num_layers: number of Linear+ReLU blocks.
        return_intermediate: when true, ``forward`` also returns the output of
            every block.
    """

    def __init__(self, embed_dim=16, num_layers=2, return_intermediate=False):
        super().__init__()
        self.return_intermediate = return_intermediate

        # Patch embedding: a stride-4 conv cuts the image into 4x4 patches,
        # (B, 3, H, W) -> (B, embed_dim, H/4, W/4); the spatial grid is then
        # flattened and the axes swapped to give (B, seq, embed_dim).
        patchify = nn.Conv2d(3, embed_dim, kernel_size=4, stride=4)
        self.to_patch_embed = nn.Sequential(
            patchify,
            nn.Flatten(start_dim=2),
            Rearrange('b c s -> b s c'),
        )

        # A very rough stand-in for transformer layers.
        blocks = [
            nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Embed ``x`` into patches, run every block, and wrap the result.

        Returns:
            EncoderOutput with ``last_hidden_state`` set to the final block's
            output and ``hidden_states`` set to the list of all block outputs,
            or ``None`` when ``return_intermediate`` is false.
        """
        state = self.to_patch_embed(x)  # (B, seq_len, embed_dim)

        # Collect per-block outputs only when the caller asked for them.
        collected = [] if self.return_intermediate else None
        for block in self.layers:
            state = block(state)  # shape is preserved by every block
            if collected is not None:
                collected.append(state)

        return EncoderOutput(last_hidden_state=state, hidden_states=collected)

In [21]:
# Dummy batch of two 3-channel 32x32 images; the stride-4 patch embedding
# turns each image into 8 * 8 = 64 patches.
x = torch.randn((2, 3, 32, 32))
model = Encoder(
    embed_dim=16,
    num_layers=2,
    return_intermediate=True,
)
output = model(x)


In [22]:
# Shape of the final layer's output.
final_shape = output.last_hidden_state.shape
print("last_hidden_state shape:", final_shape)

# Per-layer outputs are present only when intermediates were requested.
intermediates = output.hidden_states
if intermediates is not None:
    for i, hs in enumerate(intermediates):
        print(f"hidden_states[{i}] shape:", hs.shape)

# The same value can be read with dict-style key lookup ...
print("Check by dict key:", output["last_hidden_state"].shape)
# ... and by position (0 -> 'last_hidden_state', 1 -> 'hidden_states').
first_entry = output[0]
print("Check by index [0]:", None if first_entry is None else first_entry.shape)

last_hidden_state shape: torch.Size([2, 64, 16])
hidden_states[0] shape: torch.Size([2, 64, 16])
hidden_states[1] shape: torch.Size([2, 64, 16])
Check by dict key: torch.Size([2, 64, 16])
Check by index [0]: torch.Size([2, 64, 16])
