In [3]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn import Transformer
import torch.nn.functional as F
from torch import Tensor
from torch import nn
import warnings
import random
import torch
import math
import yaml
import json
import os
# warnings.filterwarnings("ignore")

In [4]:
class index2char():
    def __init__(self, root, tokenizer=None):
        if tokenizer is None:
            with open(root + '/tokenizer.yaml', 'r') as f:
                self.tokenizer = yaml.load(f, Loader=yaml.CLoader)
        else:
            self.tokenizer = tokenizer
    
    def __call__(self, indices:list, without_token=True):
        if type(indices) == Tensor:
            indices = indices.tolist()
        result = ''.join([self.tokenizer['index_2_char'][i] for i in indices])
        if without_token:
            result = result.split('[eos]')[0]
            result = result.replace('[sos]', '').replace('[eos]', '').replace('[pad]', '')
        return result

In [5]:
def metrics(pred:list, target:list) -> float:
    """
    pred: list of strings
    target: list of strings

    return: accuracy(%)
    """
    if len(pred) != len(target):
        raise ValueError('length of pred and target must be the same')
    correct = 0
    for i in range(len(pred)):
        if pred[i] == target[i]:
            correct += 1
    return correct / len(pred) * 100

In [6]:
embedding_num = 31
embedding_dim = 512
num_layers = 8
num_heads = 8
ff_dim = 1024
dropout = 0.1

In [7]:
class SpellCorrectionDataset(Dataset):
    def __init__(self, root, split:str = 'train', tokenizer=None, padding:int =0):
        super(SpellCorrectionDataset, self).__init__()
        #load your data here
        self.padding = padding
        
        if tokenizer:
            self.tokenizer = tokenizer
        else:
            with open(os.path.join(root, 'tokenizer.yaml'), 'r') as f:
                self.tokenizer = yaml.load(f, Loader=yaml.CLoader)
        
        data_path = os.path.join(root, f'{split}.json')
        with open(data_path, 'r') as f:
            self.all_data = json.load(f)
        self.data =[]
        for line in range(len(self.all_data)):
            for input in self.all_data[line]['input']:
                self.data.append({'input':input,'target':self.all_data[line]['target']})
    
    def tokenize(self, text:str):
        # tokenize your text here
        # ex: "data" -> [4, 1, 20, 1]
        
        # 將文本轉換為索引序列
        tokens = [self.tokenizer['char_2_index'].get(char, 0) for char in text]  # 0 可以是未識別字符的索引
        # 根據指定的padding進行填充或截斷
        if self.padding > 0:
            tokens = tokens[:self.padding] + [0] * max(0, self.padding - len(tokens))  # 使用 0 進行填充
        return tokens
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # get your data by index here
        # ex: return input_ids, target_ids
        # return type: torch.tensor
        item = self.data[index]
        input_text = item['input']
        input_ids = self.tokenize(input_text)
        target_text = item['target']
        target_ids = self.tokenize(target_text)
        # print(f"input_ids={input_ids}")
        # print(f"input_text={input_text}")
        # print(f"target_text={target_text}")
        # print(f"target_ids={target_ids}")
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)

In [8]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        self.batch_first = batch_first

    def forward(self, x: Tensor) -> Tensor:
        if self.batch_first:
            x = x.transpose(0, 1)
            # print(f"x shape: {x.shape}, pe shape: {self.pe[:x.size(0)].shape}")
            x = x + self.pe[:x.size(0)]
            return self.dropout(x.transpose(0, 1))
        else:
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)

# Transformer

In [9]:
class Encoder(nn.Module):
    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100):
        super(Encoder, self).__init__()
        self.tok_embedding = nn.Embedding(num_emb, hid_dim)
        self.pos_embedding = PositionalEncoding(hid_dim, dropout, max_length, batch_first=True)
        # self.layer = <nn.TransformerEncoderLayer>
        # self.encoder = <nn.TransformerEncoder>
        self.layer = nn.TransformerEncoderLayer(d_model=hid_dim, nhead=n_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.encoder = nn.TransformerEncoder(self.layer, num_layers=n_layers)
        

    def forward(self, src, src_mask=None, src_pad_mask=None):
        # tgt = your_embeddings(?)
        src_emb = self.tok_embedding(src)
        src_emb = self.pos_embedding(src_emb)
        # src = self.encoder(?)
        enc_src = self.encoder(src_emb, mask=src_mask, src_key_padding_mask=src_pad_mask)
        
        return enc_src

class Decoder(nn.Module):
    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100):
        super(Decoder, self).__init__()
        self.tok_embedding = nn.Embedding(num_emb, hid_dim)
        self.pos_embedding = PositionalEncoding(hid_dim, dropout, max_length, batch_first=True)
        # self.layer = <nn.TransformerDecoderLayer>
        # self.encoder = <nn.TransformerDecoder>
        self.layer = nn.TransformerDecoderLayer(d_model=hid_dim, nhead=n_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.decoder = nn.TransformerDecoder(self.layer, num_layers=n_layers)

    def forward(self, tgt, enc_src, tgt_mask=None, memory_mask=None, src_pad_mask=None, tgt_key_padding_mask=None):
        # tgt = your_embeddings(?)
        tgt_emb = self.tok_embedding(tgt)
        tgt_emb = self.pos_embedding(tgt_emb)
        
        # tgt = self.decoder(?)
        dec_output = self.decoder(tgt_emb, enc_src, tgt_mask=tgt_mask,
                                  memory_mask=memory_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=src_pad_mask)
        
        return dec_output

class TransformerAutoEncoder(nn.Module):
    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100, encoder=None):
        super(TransformerAutoEncoder, self).__init__()
        if encoder is None:
            self.encoder = Encoder(num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length)
        else:
            self.encoder = encoder
        self.decoder = Decoder(num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length)
        self.output_layer = nn.Linear(hid_dim, num_emb)

    def forward(self, src, tgt, src_pad_mask=None, tgt_mask=None, tgt_pad_mask=None):
        # enc_src = self.encoder(?)
        enc_src = self.encoder(src, src_pad_mask=src_pad_mask)
        
        # out = self.decoder(?)
        dec_out = self.decoder(tgt, enc_src, src_pad_mask=src_pad_mask, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_pad_mask)  

        out = self.output_layer(dec_out)
        return out

# Mask

In [10]:
def gen_padding_mask(src, pad_idx):
    # detect where the padding value is
    pad_mask = (src == pad_idx).transpose(0, 1)  # 生成布林掩碼，padding 位置為 True
    return pad_mask

def gen_mask(seq):
    # triu mask for decoder
    seq_len = seq.size(0)  # 獲取序列長度
    # 生成上三角掩碼，並擴展以適配注意力的計算
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask

def get_index(pred, dim=1):
    return pred.clone().argmax(dim=dim)

def random_change_idx(data: torch.Tensor, prob: float = 0.2):
    # randomly change the index of the input data
    change_mask = torch.rand(data.size()) < prob  # 隨機生成的布林掩碼
    new_values = torch.randint(0, data.size(-1), data.size(), dtype=torch.long)
    # 用新值替換原始數據中的部分值
    sample = data.clone()
    sample[change_mask] = new_values[change_mask]
    return sample

def random_masked(data: torch.Tensor, prob: float = 0.2, mask_idx: int = 3):
    # randomly mask the input data
    mask = torch.rand(data.size()) < prob  # 隨機生成的布林掩碼
    # 將選中的值替換為指定的 mask_idx
    sample = data.clone()
    sample[mask] = mask_idx
    return sample

# Pretrained encoder with random mask

In [11]:
# You can try to pretrain the Encoder here!
class PretrainedMaskedEncoder(Encoder):
    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100, pretrained=True):
        super(PretrainedMaskedEncoder, self).__init__(num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length)
        self.pretrained = pretrained

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Apply random masking before passing through the encoder
        if self.pretrained:
            src = random_masked(src, prob=0.2, mask_idx=0)  # Apply random masking to input data
        return super().forward(src, src_mask, src_key_padding_mask)

encoder = PretrainedMaskedEncoder(embedding_num, embedding_dim, num_layers, num_heads, ff_dim, dropout,  pretrained=True)

# Define some dummy input (batch_size=4, sequence_length=10)
src = torch.randint(0, embedding_num, (4, 10))

# Forward pass through the encoder
output = encoder(src)

print(output.shape)

torch.Size([4, 10, 512])




# Train our spelling correction transformer

In [12]:
from tqdm import tqdm
i2c = index2char('./data/')

trainset = SpellCorrectionDataset('./data/', padding=22)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testset = SpellCorrectionDataset('./data/', split='new_test', padding=22)
testloader = DataLoader(testset, batch_size=32, shuffle=False)
valset = SpellCorrectionDataset('./data/', split='test', padding=22)
valloader = DataLoader(valset, batch_size=32, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ce_loss = nn.CrossEntropyLoss(ignore_index=0)

In [13]:
def validation(dataloader, model, device, logout=False):
    pred_str_list = []
    tgt_str_list = []
    input_str_list = []
    losses = []
    for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            tgt_input = torch.full((tgt.shape[0], tgt.shape[1]), fill_value=0, device=device)
            tgt_input[:, 0] = 1  # Set first token to <sos>
            
            for i in range(tgt.shape[1]-1):
                # Generate the padding masks
                src_pad_mask = gen_padding_mask(src, pad_idx=0)
                tgt_pad_mask = gen_padding_mask(tgt_input, pad_idx=0)
                tgt_mask = gen_mask(tgt_input).to(device)
                
                # Forward pass
                pred = model(src, tgt_input,
                             src_pad_mask=src_pad_mask,
                             tgt_mask=tgt_mask,
                             tgt_pad_mask=tgt_pad_mask)
                
                # pred = <get the prediction idx from the model>
                # assign the prediction idx to the next token of tgt_input
                pred_idx = get_index(pred[:, -1, :])
                tgt_input[:, i + 1] = pred_idx  # 將預測結果添加到 tgt_input
                
            for i in range (tgt.shape[0]):
                pred_str_list.append(i2c(tgt_input[i].tolist()))
                tgt_str_list.append(i2c(tgt[i].tolist()))
                input_str_list.append(i2c(src[i].tolist()))
                if logout:
                    print('='*30)
                    print(f'input: {input_str_list[-1]}')
                    print(f'pred: {pred_str_list[-1]}')
                    print(f'target: {tgt_str_list[-1]}')
            loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
            losses.append(loss.item())
    print(f"test_acc: {metrics(pred_str_list, tgt_str_list):.2f}", f"test_loss: {sum(losses)/len(losses):.2f}", end=' | ')
    print(f"[pred: {pred_str_list[0]} target: {tgt_str_list[0]}]")

In [18]:
# encoder.pretrained_mode = False
model = TransformerAutoEncoder(embedding_num, embedding_dim, num_layers, num_heads, ff_dim, dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1) #choose your optimizer

for eps in range(1000):
    # train
    losses = []
    model.train()
    i_bar = tqdm(trainloader, unit='iter', desc=f'epoch{eps}')
    
    for src, tgt in i_bar:
        src, tgt = src.to(device), tgt.to(device)
        
        # generate the mask and padding mask
        src_pad_mask = gen_padding_mask(src, pad_idx=0)  # Generate the padding mask
        tgt_pad_mask = gen_padding_mask(tgt, pad_idx=0)  # Generate the padding mask
        tgt_mask = gen_mask(tgt).to(device)  # Generate the mask
        
        optimizer.zero_grad()
        
        pred = model(src, tgt,
                     src_pad_mask=src_pad_mask,
                     tgt_mask=tgt_mask,
                     tgt_pad_mask=tgt_pad_mask)
        
        # print(f"pred shape: {pred.shape}")
        
        loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        i_bar.set_postfix_str(f"loss: {sum(losses)/len(losses):.3f}")
        
    # test
    model.eval()
    with torch.no_grad():
        validation(testloader, model, device)
    model.eval()
    with torch.no_grad():
        validation(valloader, model, device)

epoch0:   1%|          | 3/404 [00:00<00:15, 26.53iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:   2%|▏         | 9/404 [00:00<00:14, 26.86iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:   4%|▍         | 18/404 [00:00<00:13, 28.10iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:   5%|▌         | 22/404 [00:00<00:13, 28.98iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:   8%|▊         | 31/404 [00:01<00:12, 28.76iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:   8%|▊         | 34/404 [00:01<00:13, 28.23iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  10%|▉         | 40/404 [00:01<00:13, 27.56iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  12%|█▏        | 49/404 [00:01<00:12, 28.19iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  13%|█▎        | 52/404 [00:01<00:12, 27.94iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  14%|█▍        | 58/404 [00:02<00:12, 27.80iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  16%|█▌        | 65/404 [00:02<00:11, 28.80iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  18%|█▊        | 71/404 [00:02<00:11, 28.32iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  19%|█▉        | 77/404 [00:02<00:11, 28.02iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  21%|██        | 83/404 [00:03<00:11, 28.09iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  22%|██▏       | 89/404 [00:03<00:11, 28.15iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  24%|██▎       | 95/404 [00:03<00:11, 27.79iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  25%|██▌       | 101/404 [00:03<00:10, 27.71iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  26%|██▋       | 107/404 [00:03<00:10, 27.90iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  28%|██▊       | 113/404 [00:04<00:10, 27.60iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  29%|██▉       | 119/404 [00:04<00:10, 27.06iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  31%|███       | 125/404 [00:04<00:10, 27.09iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  32%|███▏      | 131/404 [00:04<00:10, 27.22iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  34%|███▍      | 137/404 [00:04<00:09, 27.25iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  35%|███▌      | 143/404 [00:05<00:09, 27.57iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  37%|███▋      | 149/404 [00:05<00:09, 27.45iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  39%|███▊      | 156/404 [00:05<00:08, 28.76iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  40%|████      | 162/404 [00:05<00:08, 27.90iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  42%|████▏     | 168/404 [00:06<00:08, 27.99iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  43%|████▎     | 174/404 [00:06<00:08, 28.21iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])


epoch0:  45%|████▍     | 180/404 [00:06<00:08, 27.83iter/s, loss: nan]

pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])
pred shape: torch.Size([32, 22, 31])





KeyboardInterrupt: 

In [None]:
validation(testloader, model, device, logout=True)

In [None]:
validation(valloader, model, device, logout=True)