In [10]:
import logging
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn as nn
import torch.nn.functional as F

from itertools import repeat
import collections.abc


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: Optional[Callable] = None,
            flatten: bool = True,
            bias: bool = True,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        if img_size is not None:
            self.img_size = to_2tuple(img_size)
            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = False

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
        self.flatten = False
    def forward(self, x):
        B, C, H, W = x.shape
        if self.dynamic_img_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        x = self.norm(x)
        return x

In [14]:
for s, p in zip((10, 3), (5, 5)):
    print(s,p)

10 5
3 5


In [15]:
test1 = torch.randn(1, 10+1, 30)
x = torch.randn(1, 11, 30)

y = x  + test1

In [18]:
test1.shape

torch.Size([1, 11, 30])

In [23]:
torch.stack([torch.tensor(10), torch.tensor(20), torch.tensor(30)])

tensor([10, 20, 30])

In [24]:
torch.stack([torch.arange(10)])

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [26]:
torch.stack([torch.randn(10, 20)]).shape

torch.Size([1, 10, 20])

In [11]:
pm = PatchEmbed(224, 16, 3, 768)

In [12]:
pm(torch.randn(1, 3, 224, 224)).shape

torch.Size([1, 768, 14, 14])

In [34]:
from dataclasses import dataclass
from collections import OrderedDict, UserDict

def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


@dataclass
class BaseModelOutput(OrderedDict):
    test = None
    test2 = None

In [36]:
t = OrderedDict(test=100, test2=200)

In [38]:
t

OrderedDict([('test', 100), ('test2', 200)])

In [50]:
class ModelOutput(OrderedDict):
    # def __init__(self, **kwargs):
    #     self.data = kwargs
    #     self.keys = list(kwargs.keys())
        
    def __post_init__(self):
        self.keys = list(self.data.keys())

    def __getitem__(self, key):
        if isinstance(key, int):
            key = self.keys[key]
        return self.data.get(key)

    def __setitem__(self, key, value):
        self.data[key] = value
        if key not in self.keys:
            self.keys.append(key)

In [51]:

@dataclass
class MyOutput(ModelOutput):
    test = None
    test2 = None

In [39]:
class MinimalModelOutput(OrderedDict):
    def __init__(self, **kwargs):
        self.data = kwargs
        self.keys = list(kwargs.keys())

    def __getitem__(self, key):
        if isinstance(key, int):
            # 인덱스로 조회
            key = self.keys[key]
        # 키로 조회
        return self.data.get(key)

    def __setitem__(self, key, value):
        # 데이터 설정
        self.data[key] = value
        if key not in self.keys:
            self.keys.append(key)

    def __getattr__(self, key):
        if key in self.data:
            return self.data[key]
        raise AttributeError(f"'MinimalModelOutput' object has no attribute '{key}'")


In [59]:
class ModelOutput(OrderedDict):
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __getitem__(self, key):
        if isinstance(key, int):
            key = list(self.__dict__.keys())[key]
        return getattr(self, key, None)

    def __setitem__(self, key, value):
        setattr(self, key, value)

from typing import Optional, Tuple
import torch

@dataclass
class BaseModelOutput(ModelOutput):
    test = None
    test2 = None

In [63]:
t = ModelOutput(test=100, test2=330)
t[1]

330

In [52]:
t = MyOutput(test=100, test2=200)

TypeError: __init__() got an unexpected keyword argument 'test'

In [46]:
t['test']

100

In [35]:
BaseModelOutput(test=100, test2=200)

TypeError: __init__() got an unexpected keyword argument 'test'

In [30]:
t = torch.randn(4, 3, 10)
a, b =t.chunk(2)
a.shape

torch.Size([2, 3, 10])

In [4]:
class test():
    def __init__(self, *,a,b):
        super().__init__()
        self.a = a

In [None]:
test(a=10, b=20)

In [None]:
class LlamaAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        

        if (head_dim * num_heads) != hidden_size:
            raise ValueError(f"not divisible")
        
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, self.num_heads, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)


        self._init_rope()
        
    def _init_rope(self):
        self.rotary_emb = ...

    def forward(self, hidden_states : torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Cache] = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()
        

In [None]:
import torcha

from einops import rearrange
import torcah.nn as nn

def exists(val):
    return val is not None

def divisible_by(num, den):
    return (num % den) == 0

class ViT(nn.Moduel):
    def __init__(
        self, image_size, patch_size, attn_layers, channels=3, num_classes=None,
        post_emb_n
    )