n_embd - входной параметр

forward input with shape : (b_size, seq_size, n_embd) <=> (B,T,C)

In [1]:
import os

from dotenv import load_dotenv, find_dotenv

import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms

import wandb

__ENV_FILE = find_dotenv(f'{os.getenv("ENV", "var")}.env')
load_dotenv(__ENV_FILE)

True

In [2]:
class KAN(nn.Module):
    def __init__(self, fin, hid=None, fout=None):
        super().__init__()
        self.fin = fin
        self.hid = hid
        self.fout = fout

        self.last = nn.Linear(fin, fout)
        self.softm = nn.Softmax(dim=0)

        
        if hid is None or fout is None:
            self._init_default_hid_and_fout_values()
        
    def _init_default_hid_and_fout_values(self):
        self.hid = self.fin * 4
        self.fout = self.fin
        
    
    def forward(self, X):
        #  X : (B,T,C), C = n_emb
        X = self.last(X)
        X = self.softm(X)
        return X

In [9]:
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.fin = nn.Softmax(dim=1)
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.fin(out)
        
        return out

In [12]:
wandb_api_key = os.environ.get('WANDB_API_KEY')
wandb.login(key=wandb_api_key)

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
trainset = torch.utils.data.Subset(dataset, [i for i in range(100)])
train_loader = DataLoader(trainset)


input_size = 784
hidden_size = 128
output_size = 10
num_epochs = 10
batch_size = 784
learning_rate = 1e-3

run = wandb.init(
    project="KAN",
    config={
        "learning_rate": learning_rate,
        "epochs": num_epochs,
        "entity": "staff",
        "group": "creating_kan",
        "name": "KAN",
    },
)

model = MLP(input_size, hidden_size, output_size)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
dataloader = train_loader

wandb.watch(model, log="all")
for epoch in range(num_epochs):
    for batch, labels in dataloader:
        for inputs in batch:
            inputs = inputs.view(inputs.shape[0], -1)
            
            logits = model(inputs)
            pred_label = torch.argmax(logits, dim=1)
            loss = criterion(logits, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss})



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,▇▇▇▇▂▁▂▁▂██▂▃▂▃▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,1.46376


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111220639997757, max=1.0)…

loss:2.3152287006378174
loss:2.291801691055298
loss:2.2991888523101807
loss:2.301562786102295
loss:2.3174638748168945
loss:2.308088541030884
loss:2.287564516067505
loss:2.3184525966644287
loss:2.2882418632507324
loss:2.290271282196045
loss:2.303473711013794
loss:2.3023018836975098
loss:2.2901663780212402
loss:2.3298068046569824
loss:2.2687857151031494
loss:2.325249433517456
loss:2.298975944519043
loss:2.347548007965088
loss:2.3096251487731934
loss:2.31057071685791
loss:2.256385564804077
loss:2.2368617057800293
loss:2.3096978664398193
loss:2.2646541595458984
loss:2.287102699279785
loss:2.301349401473999
loss:2.287783622741699
loss:2.1574082374572754
loss:2.3033392429351807
loss:2.3210654258728027
loss:2.248307704925537
loss:2.3479163646698
loss:2.3015708923339844
loss:2.304605722427368
loss:2.2854015827178955
loss:2.2995572090148926
loss:2.3211350440979004
loss:2.1940360069274902
loss:2.3435261249542236
loss:2.2869760990142822
loss:2.244673252105713
loss:2.3548669815063477
loss:2.321493