# import libraries

In [50]:
import torchvision
from torchvision.datasets import VisionDataset
import torchvision.datasets as dset
import torch
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision import transforms
import os
from pathlib import Path
import wandb
import pytorch_lightning as pl
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# wandb setting to log

In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mrespwill[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
wandb.init(
  mode='disabled',
  # Set the project where this run will be logged
  project="Monet image gan project", 
  # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
  name="Test13", 
  # Track hyperparameters and run metadata
  config={
      "learning_rate": 0.0001,
      "batch_size":8,
      "latent_dim":100,
      "b1":0.5,
      "b2":0.999,
  })
wandb_logger = WandbLogger()

  rank_zero_warn(


# Load dataset with transform

In [None]:
class MonetDataset(Dataset):
    def __init__(self, data_monet, data_photo):
        self.data_monet = data_monet
        self.data_photo = data_photo
    
    def __len__(self):
        return min(len(self.data_monet), len(self.data_photo))
    
    def __getitem__(self, idx):
        

In [47]:
class MonetDataModule(pl.LightningDataModule):
    def __init__(self, data_dir1, data_dir2, batch_size, transform=None):
        super().__init__()
        self.data_dir_monet = data_dir1
        self.data_dir_photo = data_dir2
        self.batch_size = batch_size
        if transform == None:
            self.transform = transforms.Compose([
                transforms.ToTensor()
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
            ])
    
    def setup(self):
        dataset_monet = dset.VisionDataset(self.data_dir_monet, transform=self.transform)
        dataset_photo = dset.VisionDataset(self.data_dir_photo, transform=self.transform)
        self.dataset = torch.utils.data.ConcatDataset([dataset_monet, dataset_photo])
        
#         self.dataset_train, self.dataset_val = random_split(dataset, [int(len(dataset)*0.7), int(len(dataset)*0.3)])
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
    
#     def val_dataloader(self):
#         return torch.utils.data.DataLoader(self.dataset_val, batch_size=self.batch_size)

In [48]:
dm = MonetDataModule('./dataset/gan-getting-started/monet_jpg/','./dataset/gan-getting-started/photo_jpg/', wandb.config['batch_size'], transform=True)

In [49]:
dm.setup()

NotImplementedError: 

In [33]:
for t in dm.train_dataloader():
    print(t[0])
    print(t[1])
    break

tensor([[[[-0.3725, -0.3725, -0.3961,  ...,  0.4588,  0.1765,  0.2549],
          [-0.3882, -0.3961, -0.4118,  ...,  0.5608,  0.5373, -0.0980],
          [-0.4353, -0.4353, -0.4588,  ..., -0.2078,  0.1765, -0.0353],
          ...,
          [-0.9529, -0.9529, -0.9686,  ..., -0.3412, -0.6000, -0.5059],
          [-0.8275, -0.8353, -0.8745,  ..., -0.0039,  0.4510, -0.4039],
          [-0.8824, -0.8902, -0.9294,  ...,  0.4353, -0.3725,  0.0510]],

         [[-0.4039, -0.3725, -0.3255,  ...,  0.0039, -0.2706, -0.1765],
          [-0.4039, -0.3882, -0.3412,  ...,  0.0980,  0.0980, -0.5373],
          [-0.4275, -0.3961, -0.3647,  ..., -0.6706, -0.2706, -0.4667],
          ...,
          [-1.0000, -1.0000, -1.0000,  ..., -0.7490, -1.0000, -0.9059],
          [-0.8980, -0.8902, -0.9137,  ..., -0.5373, -0.0745, -0.9294],
          [-0.9529, -0.9451, -0.9922,  ..., -0.1451, -0.9686, -0.5451]],

         [[-0.4980, -0.4510, -0.3333,  ..., -0.6392, -0.9529, -0.8902],
          [-0.4980, -0.4510, -

# Set initial weight of Generator and Discriminator
* reference urls: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

In [6]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0)

# Generator

In [7]:
class Generator(pl.LightningModule):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape
        self.main = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Flatten(),
            torch.nn.Linear(latent_dim, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(128, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
#             torch.nn.Dropout(0.3),
#             torch.nn.Linear(512, 1024),
#             torch.nn.BatchNorm1d(1024, 0.8),
#             torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(512, int(np.prod(img_shape))),            
            torch.nn.Tanh()
        )
        
    def forward(self, input):
        img = self.main(input)
        # change shape of tensor from network
        # size(0)->batch size?
        img = img.view(img.size(0), *self.img_shape)
        return img

# Discriminator

In [8]:
class Discriminator(pl.LightningModule):
    def __init__(self, img_shape):
        super().__init__()
        self.main = torch.nn.Sequential(
            torch.nn.Linear(int(np.prod(img_shape)), 512),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(512, 256),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(256, 1),
            torch.nn.Sigmoid()
        )
    
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.main(img_flat)
        return validity

# GAN

In [9]:
class GAN(pl.LightningModule):
    def __init__(self, channels, width, height, latent_dim, lr, b1, b2, batch_size):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False
        
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, 
                                   img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)
        self.validation_z = torch.randn(8, self.hparams.latent_dim)
        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
        
    def forward(self, z):
        return self.generator(z)
    
    def adversarial_loss(self, y_hat, y):
        return torch.nn.functional.binary_cross_entropy(y_hat, y)
    
    def training_step(self, batch):
        imgs, _ = batch
#         print(imgs)
        optimizer_g, optimizer_d = self.optimizers()
        
        # add noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)
        
        # train generator
        self.toggle_optimizer(optimizer_g, optimizer_idx=0)
        self.generated_imgs = self(z)
        # get 6 samples from generated images
        sample_imgs = self.generated_imgs[:6]
        # make flattened data into grid shape to see as image
        grid = torchvision.utils.make_grid(sample_imgs)
#         self.logger.log_image("generated_images", [grid])
        
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)
        
        # binary cross-entropy
        # if generator created images well so that discriminator recognize the results as real(1), loss decrease.
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        self.log('g_loss', g_loss, prog_bar=True, on_epoch=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        # stop generator training
        self.untoggle_optimizer(optimizer_g)
        
        # train discriminator
        self.toggle_optimizer(optimizer_d, optimizer_idx=1)
        
        # check if discriminator can recognize real images as real
        # low loss means discriminator recognize real images well.
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
        
        # check if discriminator can recognize fake images as fake
        # low loss means discriminator recognize fake images.
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)
        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)
        
        # overall loss is average of these two loss.
        d_loss = (real_loss + fake_loss) / 2
        # low d_loss means discriminator recognize images well.
        self.log('d_loss', d_loss, prog_bar=True, on_epoch=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()
        
        # stop discriminator training.
        self.untoggle_optimizer(optimizer_d)
    
    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []
    
    def validation_step(self, *args):
        pass
    
    def on_validation_epoch_end(self):
#     def validation_step(self):
#         print("it is the end of validation epoch")
        z = self.validation_z.type_as(self.generator.main[0].weight)
        
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
#         grid = wandb.Image(grid)
#         self.logger.experiment.add_image("generated_images", grid, self.current_epoch)        
        self.logger.log_image("generated_images", [grid])

In [10]:
param = {"channels":3,
         "width":256, 
         "height":256,
         "latent_dim":wandb.config['latent_dim'],
         "lr":wandb.config['learning_rate'],
         "b1":wandb.config['b1'],
         "b2":wandb.config['b2'],
         "batch_size":wandb.config['batch_size']}
model = GAN(**param)

In [11]:
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=1000,
    log_every_n_steps=27,
    logger=wandb_logger
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
trainer.fit(model, dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params | In sizes | Out sizes       
------------------------------------------------------------------------------
0 | generator     | Generator     | 101 M  | [2, 100] | [2, 3, 256, 256]
1 | discriminator | Discriminator | 100 M  | ?        | ?               
------------------------------------------------------------------------------
201 M     Trainable params
0         Non-trainable params
201 M     Total params
807.338   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


it is the end of validation epoch


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

tensor([[[[ 0.7176,  0.7176,  0.7333,  ...,  0.6784,  0.6941,  0.7176],
          [ 0.7176,  0.7098,  0.7255,  ...,  0.7176,  0.7255,  0.7255],
          [ 0.7098,  0.7098,  0.7255,  ...,  0.7255,  0.7255,  0.7098],
          ...,
          [ 0.5451,  0.5608,  0.5529,  ...,  0.6706,  0.5765,  0.6314],
          [ 0.5294,  0.5529,  0.5451,  ...,  0.5294,  0.6078,  0.5451],
          [ 0.4980,  0.5451,  0.5373,  ...,  0.5608,  0.6706,  0.6157]],

         [[ 0.7490,  0.7490,  0.7333,  ...,  0.6863,  0.7020,  0.7255],
          [ 0.7490,  0.7412,  0.7412,  ...,  0.7255,  0.7333,  0.7333],
          [ 0.7412,  0.7412,  0.7412,  ...,  0.7333,  0.7333,  0.7176],
          ...,
          [ 0.4824,  0.4980,  0.5059,  ...,  0.6706,  0.5765,  0.6549],
          [ 0.4667,  0.4902,  0.4824,  ...,  0.5529,  0.6314,  0.5765],
          [ 0.4353,  0.4824,  0.4745,  ...,  0.5765,  0.7020,  0.6471]],

         [[ 0.4588,  0.4588,  0.4353,  ...,  0.5765,  0.5765,  0.6000],
          [ 0.4588,  0.4510,  

tensor([[[[ 0.4196,  0.5373,  0.5922,  ...,  0.4980,  0.5294,  0.5529],
          [ 0.4039,  0.5373,  0.6000,  ...,  0.5451,  0.5529,  0.5608],
          [ 0.3961,  0.5451,  0.6157,  ...,  0.5529,  0.5922,  0.6157],
          ...,
          [ 0.3412,  0.4902,  0.5216,  ..., -0.6157, -0.6471, -0.7176],
          [ 0.1922,  0.3569,  0.4510,  ..., -0.6157, -0.6471, -0.6941],
          [ 0.1216,  0.3647,  0.3725,  ..., -0.6235, -0.6549, -0.6863]],

         [[ 0.4118,  0.5294,  0.5686,  ...,  0.5608,  0.5608,  0.5843],
          [ 0.3961,  0.5294,  0.5765,  ...,  0.6078,  0.5843,  0.5922],
          [ 0.4039,  0.5529,  0.6078,  ...,  0.6157,  0.6235,  0.6471],
          ...,
          [ 0.4196,  0.5765,  0.5843,  ..., -0.0824, -0.1137, -0.1922],
          [ 0.2863,  0.4353,  0.5137,  ..., -0.0824, -0.1137, -0.1686],
          [ 0.2157,  0.4588,  0.4353,  ..., -0.0902, -0.1216, -0.1608]],

         [[ 0.3804,  0.4980,  0.5294,  ...,  0.5765,  0.5686,  0.5922],
          [ 0.3647,  0.4980,  

tensor([[[[-0.1451, -0.0196,  0.0824,  ..., -0.0980, -0.1059, -0.3490],
          [ 0.0196,  0.2314,  0.2392,  ...,  0.0039, -0.0431, -0.3098],
          [ 0.0745,  0.3961,  0.3255,  ...,  0.0824,  0.0431, -0.2392],
          ...,
          [-0.2392, -0.2941, -0.4431,  ...,  0.2549,  0.1843, -0.2078],
          [-0.2000, -0.2627, -0.3882,  ...,  0.2235,  0.1451, -0.0902],
          [-0.2471, -0.3569, -0.5137,  ...,  0.2471,  0.0667, -0.0039]],

         [[ 0.0275,  0.0980,  0.1059,  ...,  0.1137,  0.1059, -0.1373],
          [ 0.1922,  0.3490,  0.2627,  ...,  0.2157,  0.1529, -0.0980],
          [ 0.2471,  0.5137,  0.3490,  ...,  0.2784,  0.2157, -0.0431],
          ...,
          [-0.0353, -0.1059, -0.2784,  ..., -0.1059, -0.0588, -0.3961],
          [ 0.1373,  0.0510, -0.1608,  ..., -0.0980, -0.0824, -0.2627],
          [ 0.1608,  0.0118, -0.2549,  ..., -0.0588, -0.1373, -0.1608]],

         [[ 0.0039,  0.0431, -0.0275,  ...,  0.0745,  0.0667, -0.1765],
          [ 0.1686,  0.2941,  

tensor([[[[-0.3647, -0.0824, -0.3412,  ..., -0.3412, -0.4196, -0.4667],
          [-0.3804, -0.0118, -0.1843,  ..., -0.3412, -0.3020, -0.3725],
          [-0.6078, -0.3098, -0.2941,  ..., -0.1373, -0.1608, -0.2078],
          ...,
          [-0.5137, -0.7490, -0.5608,  ..., -0.1373, -0.1216, -0.0510],
          [-0.4588, -0.6784, -0.4745,  ..., -0.0667, -0.1059, -0.1373],
          [-0.3961, -0.5843, -0.3961,  ..., -0.0039, -0.1216, -0.2471]],

         [[-0.4902, -0.2078, -0.4510,  ..., -0.2941, -0.3961, -0.4431],
          [-0.4980, -0.1294, -0.2941,  ..., -0.2941, -0.2863, -0.3490],
          [-0.7176, -0.4196, -0.3961,  ..., -0.0980, -0.1529, -0.1922],
          ...,
          [-0.5373, -0.7647, -0.5765,  ..., -0.0510, -0.0510,  0.0039],
          [-0.4824, -0.6941, -0.4902,  ..., -0.0118, -0.0588, -0.0902],
          [-0.4196, -0.6000, -0.4118,  ...,  0.0588, -0.0902, -0.2157]],

         [[-0.4118, -0.1216, -0.3333,  ..., -0.4039, -0.5294, -0.5922],
          [-0.4588, -0.0745, -

tensor([[[[ 0.5373,  0.5216,  0.4824,  ...,  0.4588,  0.3725,  0.3569],
          [ 0.5216,  0.5059,  0.4980,  ...,  0.4196,  0.3804,  0.3804],
          [ 0.4824,  0.4824,  0.4902,  ...,  0.5216,  0.5137,  0.5373],
          ...,
          [ 0.1686,  0.0980,  0.1529,  ..., -0.1608, -0.2549, -0.2314],
          [ 0.0510,  0.0902,  0.0275,  ..., -0.1529, -0.1843, -0.0745],
          [ 0.2941,  0.2235, -0.0275,  ..., -0.0353, -0.0667,  0.0588]],

         [[ 0.8902,  0.8745,  0.8353,  ...,  0.7412,  0.6549,  0.6392],
          [ 0.8745,  0.8588,  0.8510,  ...,  0.6784,  0.6314,  0.6471],
          [ 0.8353,  0.8353,  0.8431,  ...,  0.7412,  0.7255,  0.7569],
          ...,
          [ 0.2000,  0.1451,  0.2000,  ..., -0.1686, -0.2627, -0.2235],
          [ 0.0824,  0.1216,  0.0745,  ..., -0.1294, -0.1608, -0.0510],
          [ 0.3176,  0.2627,  0.0039,  ..., -0.0118, -0.0275,  0.0980]],

         [[ 0.9137,  0.8980,  0.8588,  ...,  0.6627,  0.5765,  0.5451],
          [ 0.8980,  0.8824,  

tensor([[[[ 0.0039,  0.1843,  0.2549,  ..., -0.0980,  0.0431,  0.0588],
          [ 0.2235,  0.3647,  0.3804,  ..., -0.0902,  0.0980,  0.1451],
          [ 0.1529,  0.2941,  0.3176,  ..., -0.1373,  0.0745,  0.1686],
          ...,
          [-0.0510, -0.0588, -0.0118,  ...,  0.3098,  0.2863,  0.2784],
          [ 0.1216,  0.0824,  0.0510,  ...,  0.2863,  0.2863,  0.3490],
          [ 0.1922,  0.2078,  0.2549,  ...,  0.3020,  0.3333,  0.4824]],

         [[-0.2549, -0.0588,  0.0431,  ...,  0.2941,  0.1216, -0.0275],
          [-0.0431,  0.1216,  0.1686,  ...,  0.3098,  0.1843,  0.0980],
          [-0.0902,  0.0588,  0.1137,  ...,  0.2941,  0.2078,  0.1608],
          ...,
          [-0.0588, -0.0667, -0.0275,  ...,  0.0745, -0.0196, -0.0353],
          [ 0.0824,  0.0431,  0.0275,  ..., -0.0196, -0.0902, -0.0353],
          [ 0.1451,  0.1608,  0.2157,  ..., -0.0510, -0.0667,  0.0588]],

         [[-0.4980, -0.2863, -0.1843,  ...,  0.2863,  0.1373,  0.0039],
          [-0.2627, -0.1059, -

tensor([[[[ 0.4902,  0.4196,  0.5529,  ...,  0.5686,  0.5843,  0.5922],
          [ 0.4980,  0.4118,  0.5216,  ...,  0.5608,  0.5686,  0.5686],
          [ 0.4980,  0.4196,  0.5451,  ...,  0.5608,  0.5608,  0.5529],
          ...,
          [ 0.3725,  0.3412,  0.3333,  ...,  0.5137,  0.5529,  0.6471],
          [ 0.2863,  0.2784,  0.2706,  ...,  0.4118,  0.4745,  0.7176],
          [ 0.2392,  0.2314,  0.2157,  ...,  0.4902,  0.5529,  0.7804]],

         [[ 0.4196,  0.3490,  0.4824,  ...,  0.5608,  0.5765,  0.5843],
          [ 0.4275,  0.3412,  0.4510,  ...,  0.5529,  0.5608,  0.5608],
          [ 0.4196,  0.3490,  0.4745,  ...,  0.5529,  0.5529,  0.5451],
          ...,
          [ 0.4431,  0.4039,  0.3804,  ...,  0.2078,  0.2078,  0.3020],
          [ 0.3569,  0.3490,  0.3098,  ..., -0.0039,  0.0275,  0.2549],
          [ 0.3255,  0.3098,  0.2627,  ...,  0.0275,  0.0588,  0.2706]],

         [[ 0.4588,  0.3882,  0.5059,  ...,  0.7020,  0.7176,  0.7255],
          [ 0.4667,  0.3804,  

tensor([[[[-0.7412, -0.7490, -0.7725,  ..., -0.8588, -0.8431, -0.7490],
          [-0.7725, -0.7882, -0.8196,  ..., -0.8353, -0.8275, -0.7412],
          [-0.7569, -0.7804, -0.8118,  ..., -0.8196, -0.7961, -0.7098],
          ...,
          [ 0.0745, -0.2392, -0.2314,  ...,  0.6941,  0.6706,  0.6706],
          [ 0.0824, -0.2392, -0.2157,  ...,  0.6392,  0.6549,  0.6784],
          [ 0.1451, -0.1529, -0.1137,  ...,  0.6471,  0.7020,  0.7490]],

         [[-0.8510, -0.8588, -0.8824,  ..., -0.8902, -0.8745, -0.7804],
          [-0.8824, -0.8980, -0.9294,  ..., -0.8667, -0.8588, -0.7725],
          [-0.8667, -0.8902, -0.9216,  ..., -0.8510, -0.8275, -0.7412],
          ...,
          [ 0.0824, -0.2157, -0.2078,  ...,  0.4667,  0.4667,  0.4667],
          [ 0.0980, -0.2000, -0.1686,  ...,  0.4275,  0.4431,  0.4667],
          [ 0.1843, -0.1137, -0.0588,  ...,  0.4353,  0.4902,  0.5373]],

         [[-0.8745, -0.8824, -0.9059,  ..., -0.9765, -0.9608, -0.8667],
          [-0.9059, -0.9216, -

tensor([[[[ 0.0510,  0.0667,  0.0902,  ...,  0.0275, -0.0118, -0.0745],
          [ 0.0745,  0.0824,  0.0902,  ...,  0.0667,  0.0196, -0.0431],
          [ 0.1137,  0.1059,  0.0980,  ...,  0.0118, -0.0353, -0.1059],
          ...,
          [-0.0196, -0.0039,  0.0039,  ..., -0.8431, -0.8196, -0.7961],
          [-0.0980, -0.0353,  0.0196,  ..., -0.8118, -0.7882, -0.7647],
          [-0.1451, -0.0510,  0.0510,  ..., -0.7804, -0.7569, -0.7412]],

         [[-0.0510, -0.0353, -0.0196,  ..., -0.0118, -0.0510, -0.1137],
          [-0.0275, -0.0196, -0.0196,  ...,  0.0275, -0.0196, -0.0824],
          [ 0.0118,  0.0039, -0.0118,  ..., -0.0275, -0.0745, -0.1451],
          ...,
          [-0.1216, -0.1059, -0.0902,  ..., -0.8667, -0.8431, -0.8196],
          [-0.1843, -0.1137, -0.0510,  ..., -0.8353, -0.8118, -0.7882],
          [-0.2314, -0.1294, -0.0275,  ..., -0.8039, -0.7804, -0.7647]],

         [[-0.4039, -0.3882, -0.3882,  ..., -0.2471, -0.2863, -0.3490],
          [-0.3804, -0.3725, -

tensor([[[[ 0.2863,  0.2314,  0.2078,  ...,  0.3647,  0.2941,  0.2471],
          [ 0.2706,  0.2314,  0.2235,  ...,  0.3569,  0.3020,  0.2549],
          [ 0.2471,  0.2314,  0.2549,  ...,  0.3569,  0.3020,  0.2706],
          ...,
          [ 0.0745,  0.0824,  0.0824,  ..., -0.4431, -0.4902, -0.6235],
          [ 0.0510,  0.0745,  0.0824,  ..., -0.3725, -0.3569, -0.5216],
          [ 0.0275,  0.0667,  0.1059,  ..., -0.0431,  0.0275, -0.1373]],

         [[ 0.5373,  0.4824,  0.4588,  ...,  0.5922,  0.5216,  0.4745],
          [ 0.5216,  0.4824,  0.4745,  ...,  0.5843,  0.5294,  0.4824],
          [ 0.4980,  0.4824,  0.5059,  ...,  0.5843,  0.5294,  0.4980],
          ...,
          [ 0.2314,  0.2314,  0.2314,  ..., -0.2706, -0.3333, -0.4667],
          [ 0.2078,  0.2235,  0.2314,  ..., -0.2000, -0.2000, -0.3647],
          [ 0.1843,  0.2157,  0.2549,  ...,  0.1294,  0.1843,  0.0196]],

         [[ 0.6235,  0.5686,  0.5608,  ...,  0.5529,  0.4824,  0.4353],
          [ 0.6078,  0.5686,  

tensor([[[[-0.1451, -0.1451, -0.1529,  ..., -0.1529, -0.1137, -0.0980],
          [-0.2314, -0.2471, -0.2314,  ..., -0.1765, -0.1451, -0.1373],
          [-0.2784, -0.3098, -0.2863,  ..., -0.1451, -0.1373, -0.1373],
          ...,
          [-0.3333, -0.3176, -0.2549,  ..., -0.1137,  0.0118,  0.1059],
          [-0.3020, -0.2392, -0.2000,  ..., -0.1294,  0.0510,  0.0667],
          [-0.3725, -0.1529, -0.0824,  ..., -0.2549, -0.1216, -0.2392]],

         [[-0.0118,  0.0039, -0.0039,  ...,  0.1059,  0.1216,  0.1373],
          [-0.0980, -0.0980, -0.0745,  ...,  0.0824,  0.0902,  0.0980],
          [-0.1294, -0.1451, -0.1059,  ...,  0.1137,  0.1216,  0.1216],
          ...,
          [-0.1922, -0.1529, -0.1137,  ..., -0.0667,  0.0510,  0.1686],
          [-0.2000, -0.1216, -0.0980,  ..., -0.0745,  0.0980,  0.1294],
          [-0.2863, -0.0667,  0.0039,  ..., -0.2000, -0.0510, -0.1765]],

         [[ 0.2078,  0.2392,  0.2471,  ...,  0.6471,  0.6706,  0.6863],
          [ 0.1216,  0.1373,  

tensor([[[[-0.5137, -0.5059, -0.4745,  ..., -0.2627, -0.3098, -0.2314],
          [-0.4667, -0.4824, -0.4745,  ..., -0.2549, -0.3098, -0.2549],
          [-0.4196, -0.4588, -0.4588,  ..., -0.2471, -0.3333, -0.2941],
          ...,
          [-0.3490, -0.3725, -0.4667,  ...,  0.3098,  0.2392,  0.0510],
          [-0.3569, -0.1843, -0.2627,  ...,  0.0824,  0.0196, -0.0510],
          [-0.1765, -0.0118, -0.1294,  ...,  0.2235,  0.1922,  0.2314]],

         [[-0.3412, -0.3255, -0.2784,  ..., -0.0980, -0.1216, -0.0431],
          [-0.3020, -0.3020, -0.2627,  ..., -0.0667, -0.1216, -0.0510],
          [-0.2471, -0.2627, -0.2392,  ..., -0.0431, -0.1294, -0.0824],
          ...,
          [-0.3725, -0.4196, -0.5686,  ...,  0.2706,  0.2000,  0.0353],
          [-0.4902, -0.3333, -0.4510,  ...,  0.0745,  0.0118, -0.0431],
          [-0.3647, -0.2235, -0.3490,  ...,  0.2157,  0.2000,  0.2392]],

         [[-0.2314, -0.2000, -0.1216,  ...,  0.1137,  0.0824,  0.1608],
          [-0.1686, -0.1608, -

tensor([[[[-0.2000, -0.1686, -0.0980,  ...,  0.1137, -0.1529, -0.4980],
          [-0.2471, -0.1451, -0.0510,  ..., -0.1451,  0.1059,  0.1137],
          [-0.3020, -0.1373, -0.0039,  ..., -0.0431,  0.2392,  0.1765],
          ...,
          [-0.5686, -0.6000, -0.6549,  ..., -0.7333, -0.7255, -0.6392],
          [-0.5608, -0.5843, -0.6706,  ..., -0.6941, -0.7098, -0.6392],
          [-0.6235, -0.6235, -0.6941,  ..., -0.5922, -0.6471, -0.6078]],

         [[ 0.0510,  0.0824,  0.1765,  ...,  0.4196,  0.1608, -0.1765],
          [ 0.0039,  0.1059,  0.2157,  ...,  0.1608,  0.4039,  0.4353],
          [-0.0353,  0.1294,  0.2627,  ...,  0.2471,  0.5216,  0.4824],
          ...,
          [-0.3961, -0.4353, -0.5059,  ..., -0.8118, -0.8039, -0.7098],
          [-0.3961, -0.4196, -0.5059,  ..., -0.7804, -0.7961, -0.7255],
          [-0.4588, -0.4588, -0.5294,  ..., -0.6941, -0.7333, -0.6941]],

         [[ 0.5137,  0.5451,  0.6314,  ..., -0.4118, -0.7333, -1.0000],
          [ 0.4824,  0.5686,  

       device='cuda:0')


Validation: 0it [00:00, ?it/s]

it is the end of validation epoch
tensor([[[[ 0.3961,  0.4510,  0.4980,  ...,  0.0824,  0.0118, -0.0588],
          [ 0.3412,  0.3882,  0.4353,  ..., -0.0039, -0.0275, -0.0353],
          [ 0.3020,  0.3333,  0.3725,  ...,  0.0039, -0.0196, -0.0431],
          ...,
          [-0.5137, -0.5059, -0.4588,  ...,  0.0353,  0.0431,  0.0667],
          [-0.5294, -0.4902, -0.4118,  ...,  0.1059,  0.0588,  0.0510],
          [-0.5216, -0.5608, -0.5216,  ...,  0.1137,  0.1059,  0.1216]],

         [[ 0.3098,  0.3647,  0.4118,  ...,  0.0902,  0.0196, -0.0510],
          [ 0.2549,  0.3020,  0.3490,  ...,  0.0039, -0.0196, -0.0275],
          [ 0.2157,  0.2471,  0.2863,  ...,  0.0118, -0.0118, -0.0353],
          ...,
          [-0.5373, -0.5294, -0.4824,  ..., -0.0824, -0.0745, -0.0510],
          [-0.5686, -0.5137, -0.4353,  ...,  0.0039, -0.0431, -0.0510],
          [-0.5608, -0.6000, -0.5451,  ...,  0.0039, -0.0039,  0.0118]],

         [[ 0.1373,  0.1922,  0.2392,  ...,  0.0431, -0.0275, -0.098

tensor([[[[-0.3804, -0.3647, -0.3333,  ..., -0.4275, -0.5137, -0.6784],
          [-0.3961, -0.3804, -0.3412,  ..., -0.3882, -0.4745, -0.5922],
          [-0.3961, -0.3804, -0.3490,  ..., -0.4196, -0.4980, -0.5216],
          ...,
          [ 0.3412,  0.2314,  0.1294,  ..., -0.9216, -0.8980, -0.8745],
          [ 0.3412,  0.2471,  0.1451,  ..., -0.9216, -0.9059, -0.8902],
          [ 0.3098,  0.2863,  0.2627,  ..., -0.9216, -0.9216, -0.9059]],

         [[-0.0039,  0.0118,  0.0275,  ..., -0.3882, -0.4745, -0.6392],
          [-0.0275, -0.0039,  0.0196,  ..., -0.3490, -0.4353, -0.5529],
          [-0.0431, -0.0196, -0.0039,  ..., -0.3804, -0.4588, -0.4824],
          ...,
          [ 0.1137,  0.0275, -0.0745,  ..., -0.7176, -0.6784, -0.6549],
          [ 0.1137,  0.0353, -0.0667,  ..., -0.7176, -0.6863, -0.6706],
          [ 0.0745,  0.0745,  0.0431,  ..., -0.7176, -0.7020, -0.6863]],

         [[ 0.1765,  0.1922,  0.2157,  ..., -0.6235, -0.7098, -0.8745],
          [ 0.1765,  0.1765,  

tensor([[[[-0.7176, -0.6706, -0.5922,  ..., -0.6706, -0.6078, -0.4745],
          [-0.5137, -0.4275, -0.3098,  ..., -0.2471, -0.2549, -0.3725],
          [-0.5373, -0.4039, -0.2392,  ..., -0.3804, -0.3412, -0.3882],
          ...,
          [-0.4902, -0.4667, -0.2706,  ..., -0.4039, -0.2471, -0.6314],
          [-0.6000, -0.6392, -0.5373,  ..., -0.3882, -0.3490, -0.6235],
          [-0.6392, -0.5922, -0.5294,  ..., -0.2941, -0.4588, -0.6235]],

         [[-0.4510, -0.3882, -0.2941,  ..., -0.2471, -0.1843, -0.0510],
          [-0.2471, -0.1451, -0.0118,  ...,  0.1922,  0.1843,  0.0667],
          [-0.2549, -0.1216,  0.0588,  ...,  0.0667,  0.1059,  0.0588],
          ...,
          [-0.2706, -0.2471, -0.0510,  ..., -0.2471, -0.0667, -0.4353],
          [-0.3804, -0.4196, -0.3176,  ..., -0.2235, -0.1843, -0.4353],
          [-0.4196, -0.3725, -0.3098,  ..., -0.1451, -0.2941, -0.4588]],

         [[-0.4431, -0.3882, -0.2863,  ..., -0.0588,  0.0039,  0.1373],
          [-0.2392, -0.1451, -

tensor([[[[ 0.1137,  0.1765,  0.2392,  ...,  0.3647,  0.3412,  0.3176],
          [ 0.1765,  0.2078,  0.2392,  ...,  0.3961,  0.3647,  0.3412],
          [ 0.2235,  0.2392,  0.2392,  ...,  0.4039,  0.3804,  0.3569],
          ...,
          [ 0.0431,  0.0980,  0.1765,  ...,  0.1843,  0.1843,  0.1686],
          [-0.0275,  0.0039,  0.0588,  ...,  0.1294,  0.1373,  0.1373],
          [-0.0353, -0.0431, -0.0275,  ...,  0.0353,  0.0667,  0.0824]],

         [[ 0.2471,  0.3098,  0.3412,  ...,  0.3333,  0.3098,  0.2863],
          [ 0.3098,  0.3412,  0.3412,  ...,  0.3569,  0.3333,  0.3098],
          [ 0.3569,  0.3725,  0.3412,  ...,  0.3647,  0.3490,  0.3255],
          ...,
          [-0.1216, -0.0667,  0.0118,  ...,  0.1059,  0.0902,  0.0588],
          [-0.1843, -0.1529, -0.0980,  ...,  0.0510,  0.0431,  0.0431],
          [-0.1922, -0.2000, -0.1843,  ..., -0.0196, -0.0275, -0.0118]],

         [[ 0.3098,  0.3725,  0.4118,  ...,  0.2627,  0.2392,  0.2000],
          [ 0.3725,  0.4039,  

tensor([[[[ 0.4196,  0.4353,  0.4745,  ...,  0.6784,  0.6863,  0.6863],
          [ 0.4196,  0.4431,  0.4745,  ...,  0.6549,  0.6627,  0.6706],
          [ 0.4196,  0.4431,  0.4824,  ...,  0.6235,  0.6314,  0.6392],
          ...,
          [ 0.6157,  0.6000,  0.5922,  ...,  0.3412,  0.3176,  0.3098],
          [ 0.6471,  0.6314,  0.6157,  ...,  0.3490,  0.3020,  0.2706],
          [ 0.6549,  0.6392,  0.6235,  ...,  0.3647,  0.3020,  0.2627]],

         [[ 0.4196,  0.4353,  0.4510,  ...,  0.5686,  0.5843,  0.5843],
          [ 0.4196,  0.4431,  0.4510,  ...,  0.5529,  0.5608,  0.5686],
          [ 0.4196,  0.4431,  0.4588,  ...,  0.5216,  0.5294,  0.5373],
          ...,
          [ 0.5529,  0.5373,  0.5294,  ...,  0.3490,  0.3255,  0.3255],
          [ 0.5608,  0.5451,  0.5294,  ...,  0.3804,  0.3333,  0.3098],
          [ 0.5765,  0.5608,  0.5373,  ...,  0.3961,  0.3333,  0.3020]],

         [[-0.0039,  0.0118,  0.0196,  ...,  0.2627,  0.2549,  0.2549],
          [-0.0039,  0.0196,  

tensor([[[[ 0.2314,  0.1529,  0.1451,  ...,  0.3882,  0.3255,  0.3490],
          [ 0.2235,  0.1529,  0.1529,  ...,  0.3804,  0.3333,  0.3569],
          [ 0.2157,  0.1608,  0.1765,  ...,  0.3961,  0.3333,  0.3569],
          ...,
          [-0.5373, -0.5922, -0.4824,  ...,  0.0353, -0.1137,  0.0275],
          [-0.3647, -0.5373, -0.4118,  ..., -0.1373, -0.1765,  0.0824],
          [-0.3176, -0.5137, -0.3490,  ..., -0.2157, -0.2784,  0.1294]],

         [[ 0.2627,  0.2078,  0.2000,  ...,  0.4275,  0.3647,  0.3882],
          [ 0.2549,  0.2078,  0.2078,  ...,  0.4196,  0.3725,  0.3961],
          [ 0.2471,  0.2157,  0.2314,  ...,  0.4353,  0.3882,  0.4118],
          ...,
          [-0.3882, -0.4667, -0.3804,  ..., -0.1686, -0.3255, -0.2000],
          [-0.2392, -0.4196, -0.3255,  ..., -0.3804, -0.4431, -0.1922],
          [-0.2000, -0.4039, -0.2627,  ..., -0.4745, -0.5765, -0.1843]],

         [[ 0.3490,  0.2863,  0.2784,  ...,  0.4510,  0.3961,  0.4196],
          [ 0.3412,  0.2863,  

tensor([[[[ 0.8118,  0.6235,  0.6784,  ...,  0.6000,  0.5216,  0.5765],
          [ 0.7333,  0.6314,  0.2157,  ...,  0.5922,  0.5216,  0.5686],
          [ 0.6941,  0.4510,  0.3647,  ...,  0.5686,  0.5059,  0.5294],
          ...,
          [-0.2706, -0.4353, -0.4980,  ..., -0.4118, -0.7804, -0.7176],
          [-0.3961, -0.4667, -0.3412,  ..., -0.2157, -0.7176, -0.3647],
          [-0.2627, -0.3882, -0.3255,  ..., -0.2941, -0.4039, -0.1059]],

         [[ 0.8510,  0.6706,  0.6941,  ...,  0.5373,  0.4353,  0.4588],
          [ 0.7725,  0.6784,  0.2314,  ...,  0.5294,  0.4353,  0.4510],
          [ 0.7412,  0.4980,  0.3804,  ...,  0.5059,  0.4196,  0.4118],
          ...,
          [-0.3333, -0.4980, -0.5608,  ..., -0.3804, -0.7490, -0.6863],
          [-0.4275, -0.4980, -0.3882,  ..., -0.1843, -0.6863, -0.3333],
          [-0.2784, -0.4039, -0.3647,  ..., -0.2627, -0.3725, -0.0745]],

         [[ 0.7961,  0.5922,  0.5922,  ...,  0.4510,  0.4196,  0.4824],
          [ 0.7176,  0.6000,  

tensor([[[[ 0.3882,  0.4588,  0.5294,  ...,  0.3412,  0.2314,  0.1216],
          [ 0.4275,  0.4510,  0.4667,  ...,  0.2392,  0.2157,  0.1686],
          [ 0.4824,  0.4902,  0.4824,  ...,  0.2314,  0.1922,  0.1529],
          ...,
          [-0.4275, -0.5686, -0.4588,  ..., -0.1686, -0.0431, -0.0353],
          [-0.0353, -0.4039, -0.3804,  ...,  0.0118,  0.1608,  0.1059],
          [ 0.4275, -0.1529, -0.2627,  ...,  0.0431, -0.2078, -0.0745]],

         [[ 0.5294,  0.6000,  0.6627,  ...,  0.5059,  0.3961,  0.2863],
          [ 0.5686,  0.5922,  0.6000,  ...,  0.4039,  0.3804,  0.3333],
          [ 0.6157,  0.6235,  0.6157,  ...,  0.3961,  0.3569,  0.3176],
          ...,
          [-0.4431, -0.5843, -0.4588,  ..., -0.1686, -0.0588, -0.0431],
          [-0.0667, -0.4196, -0.3804,  ...,  0.0118,  0.1451,  0.0745],
          [ 0.3961, -0.1608, -0.2784,  ...,  0.0353, -0.2235, -0.1059]],

         [[ 0.4039,  0.4745,  0.5373,  ...,  0.3882,  0.2627,  0.1529],
          [ 0.4431,  0.4667,  

tensor([[[[ 0.0980,  0.1451,  0.1608,  ...,  0.3020,  0.1529,  0.0118],
          [ 0.0353,  0.0667,  0.0667,  ...,  0.2941,  0.1373, -0.0118],
          [ 0.0980,  0.1059,  0.0588,  ...,  0.2235,  0.0902, -0.0588],
          ...,
          [ 0.1686,  0.1922,  0.2392,  ..., -0.1608, -0.0588,  0.0118],
          [ 0.2471,  0.2706,  0.3098,  ..., -0.1294, -0.0824, -0.0431],
          [ 0.2235,  0.2549,  0.3176,  ..., -0.1216, -0.1137, -0.1216]],

         [[-0.1922, -0.1451, -0.1451,  ..., -0.0510, -0.1843, -0.3098],
          [-0.2549, -0.2235, -0.2392,  ..., -0.0431, -0.1608, -0.3020],
          [-0.1765, -0.1686, -0.2314,  ..., -0.0667, -0.1686, -0.3098],
          ...,
          [-0.1294, -0.1059, -0.0588,  ..., -0.4353, -0.3569, -0.2863],
          [-0.0980, -0.0745, -0.0353,  ..., -0.4039, -0.3804, -0.3412],
          [-0.1373, -0.1059, -0.0431,  ..., -0.3961, -0.4118, -0.4196]],

         [[-0.3412, -0.2941, -0.2784,  ..., -0.2314, -0.3333, -0.4667],
          [-0.4039, -0.3725, -

tensor([[[[-0.2784, -0.2471, -0.0039,  ...,  0.5137,  0.7412,  0.6941],
          [-0.0980, -0.3647, -0.4275,  ...,  0.4431,  0.7098,  0.8118],
          [-0.4353, -0.4824, -0.4902,  ...,  0.5843,  0.7961,  0.8196],
          ...,
          [-0.4039, -0.4118, -0.4196,  ..., -0.2314, -0.1843, -0.0667],
          [-0.0275, -0.1608, -0.3098,  ..., -0.1294, -0.1686, -0.2000],
          [ 0.2471,  0.1765,  0.0275,  ...,  0.2471,  0.3098,  0.3412]],

         [[-0.2706, -0.2471, -0.0196,  ...,  0.2157,  0.5137,  0.4902],
          [-0.0902, -0.3647, -0.4431,  ...,  0.1451,  0.4824,  0.6078],
          [-0.4353, -0.4824, -0.4902,  ...,  0.2941,  0.5451,  0.6000],
          ...,
          [-0.2627, -0.2471, -0.2549,  ..., -0.2549, -0.2235, -0.1059],
          [ 0.1451,  0.0196, -0.1294,  ..., -0.1529, -0.2078, -0.2392],
          [ 0.4353,  0.3569,  0.2314,  ...,  0.2392,  0.2706,  0.3020]],

         [[-0.3098, -0.2627, -0.0118,  ..., -0.4353, -0.1137, -0.1137],
          [-0.1294, -0.3804, -

tensor([[[[ 0.4667,  0.4510,  0.4275,  ...,  0.3020,  0.3333,  0.3490],
          [ 0.4667,  0.4510,  0.4353,  ...,  0.3333,  0.3647,  0.3725],
          [ 0.4824,  0.4745,  0.4588,  ...,  0.3804,  0.4118,  0.4196],
          ...,
          [ 0.0039,  0.0039, -0.0745,  ...,  0.5216,  0.4667,  0.2000],
          [ 0.0275,  0.0510,  0.0353,  ...,  0.6157,  0.5451,  0.3647],
          [-0.0510,  0.0039,  0.1294,  ...,  0.7412,  0.5843,  0.5294]],

         [[ 0.3804,  0.3647,  0.3412,  ...,  0.2706,  0.2784,  0.2941],
          [ 0.3804,  0.3647,  0.3490,  ...,  0.3020,  0.3098,  0.3176],
          [ 0.3725,  0.3647,  0.3490,  ...,  0.3490,  0.3569,  0.3647],
          ...,
          [-0.0275, -0.0275, -0.1216,  ...,  0.4196,  0.3647,  0.0980],
          [-0.0196,  0.0039, -0.0353,  ...,  0.5059,  0.4275,  0.2471],
          [-0.0980, -0.0431,  0.0588,  ...,  0.6235,  0.4667,  0.4118]],

         [[ 0.3490,  0.3333,  0.3098,  ...,  0.3804,  0.4039,  0.4196],
          [ 0.3490,  0.3333,  

tensor([[[[ 0.6471,  0.4980,  0.3961,  ...,  0.4588,  0.4902,  0.4745],
          [ 0.5216,  0.4118,  0.3569,  ...,  0.4275,  0.4431,  0.4353],
          [ 0.4510,  0.3961,  0.3961,  ...,  0.3961,  0.4118,  0.4118],
          ...,
          [ 0.5373,  0.5608,  0.5059,  ..., -0.3176, -0.3333, -0.2863],
          [ 0.4196,  0.3804,  0.3412,  ..., -0.3412, -0.3490, -0.2863],
          [ 0.3961,  0.3412,  0.3882,  ..., -0.3098, -0.3412, -0.3020]],

         [[ 0.5922,  0.4431,  0.3412,  ...,  0.1765,  0.2000,  0.1608],
          [ 0.4667,  0.3569,  0.3020,  ...,  0.1451,  0.1529,  0.1216],
          [ 0.3882,  0.3333,  0.3333,  ...,  0.1216,  0.1137,  0.0980],
          ...,
          [ 0.4824,  0.5137,  0.4588,  ..., -0.4588, -0.4745, -0.4275],
          [ 0.3882,  0.3490,  0.3098,  ..., -0.4824, -0.4902, -0.4275],
          [ 0.3647,  0.3098,  0.3569,  ..., -0.4510, -0.4824, -0.4431]],

         [[ 0.4667,  0.3176,  0.2157,  ...,  0.0039, -0.0039, -0.0431],
          [ 0.3412,  0.2314,  

tensor([[[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000]],

         [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000]],

         [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  

       device='cuda:0')


Validation: 0it [00:00, ?it/s]

it is the end of validation epoch
tensor([[[[ 0.9373,  0.9216,  0.9137,  ...,  0.8196,  0.8275,  0.8353],
          [ 0.9373,  0.9294,  0.9137,  ...,  0.8196,  0.8353,  0.8510],
          [ 0.9451,  0.9294,  0.9137,  ...,  0.8275,  0.8431,  0.8588],
          ...,
          [-0.1608, -0.2000, -0.2235,  ..., -0.0902, -0.0745, -0.0824],
          [-0.2549, -0.2784, -0.2784,  ..., -0.0275,  0.0039,  0.0118],
          [-0.2471, -0.2627, -0.2471,  ...,  0.0667,  0.1137,  0.1294]],

         [[ 0.8039,  0.7882,  0.7804,  ...,  0.6863,  0.6941,  0.7020],
          [ 0.8039,  0.7961,  0.7804,  ...,  0.6863,  0.7020,  0.7176],
          [ 0.8118,  0.7961,  0.7804,  ...,  0.6941,  0.7098,  0.7255],
          ...,
          [-0.2784, -0.3176, -0.3412,  ..., -0.2157, -0.2000, -0.2078],
          [-0.3725, -0.3961, -0.3961,  ..., -0.1529, -0.1216, -0.1137],
          [-0.3647, -0.3804, -0.3647,  ..., -0.0588, -0.0118,  0.0039]],

         [[ 0.6000,  0.5843,  0.5765,  ...,  0.5294,  0.5373,  0.545

tensor([[[[-0.2078, -0.0039,  0.2314,  ...,  0.6235,  0.6000,  0.5765],
          [-0.4275, -0.4431,  0.0431,  ...,  0.5843,  0.5686,  0.5451],
          [ 0.0353, -0.2392, -0.0118,  ...,  0.5686,  0.5451,  0.5216],
          ...,
          [ 0.0510, -0.1059, -0.0745,  ..., -0.0118, -0.0980, -0.0353],
          [ 0.0118, -0.1137, -0.0824,  ..., -0.0118, -0.0353,  0.0039],
          [-0.0510, -0.1373, -0.0980,  ...,  0.0431, -0.0118, -0.0275]],

         [[-0.3647, -0.1608,  0.0824,  ...,  0.6078,  0.5686,  0.5451],
          [-0.5843, -0.6000, -0.1059,  ...,  0.5686,  0.5373,  0.5137],
          [-0.1137, -0.3882, -0.1608,  ...,  0.5529,  0.5137,  0.4902],
          ...,
          [-0.1059, -0.2706, -0.2706,  ...,  0.0980,  0.0353,  0.0980],
          [-0.1529, -0.2863, -0.2863,  ...,  0.0980,  0.0980,  0.1451],
          [-0.2157, -0.3176, -0.3020,  ...,  0.1765,  0.1216,  0.1137]],

         [[-0.1529,  0.0353,  0.2549,  ...,  0.5137,  0.4824,  0.4588],
          [-0.3882, -0.4039,  

tensor([[[[ 0.8980,  0.8667,  0.9529,  ..., -0.0902, -0.3804, -0.4745],
          [ 0.8745,  0.9059,  1.0000,  ...,  0.0039, -0.3020, -0.4353],
          [ 0.8980,  0.9765,  1.0000,  ...,  0.0118, -0.2706, -0.3490],
          ...,
          [ 0.5373,  0.2549, -0.0431,  ...,  0.1294,  0.1059,  0.0824],
          [ 0.7804,  0.7961,  0.8353,  ...,  0.2549,  0.1137,  0.0902],
          [ 0.7098,  0.8667,  0.9294,  ...,  0.3569,  0.2235,  0.2471]],

         [[ 1.0000,  0.9529,  0.9843,  ..., -0.3490, -0.4510, -0.4510],
          [ 0.9686,  0.9843,  1.0000,  ..., -0.2784, -0.3961, -0.4275],
          [ 0.9451,  1.0000,  1.0000,  ..., -0.3255, -0.4039, -0.3804],
          ...,
          [ 0.9059,  0.5843,  0.2078,  ..., -0.6078, -0.5922, -0.5686],
          [ 1.0000,  1.0000,  0.9373,  ..., -0.6784, -0.5765, -0.4745],
          [ 0.9529,  1.0000,  0.9608,  ..., -0.6627, -0.4745, -0.2784]],

         [[ 1.0000,  0.9686,  0.7725,  ..., -0.6078, -0.5686, -0.5059],
          [ 1.0000,  0.9922,  

tensor([[[[ 0.0431,  0.1294,  0.2078,  ..., -0.0353, -0.0431, -0.0353],
          [ 0.0588,  0.1294,  0.1922,  ..., -0.0431, -0.0510, -0.0431],
          [ 0.0275,  0.0980,  0.1608,  ..., -0.0431, -0.0431, -0.0510],
          ...,
          [ 0.1137,  0.0431,  0.0039,  ..., -0.0824, -0.0824, -0.0667],
          [ 0.0510,  0.0431,  0.0824,  ..., -0.0745, -0.0118,  0.0275],
          [-0.0667,  0.0118,  0.1765,  ...,  0.0510,  0.1451,  0.0902]],

         [[-0.0118,  0.0980,  0.1686,  ..., -0.0745, -0.0902, -0.0824],
          [ 0.0039,  0.0980,  0.1529,  ..., -0.0824, -0.0980, -0.0902],
          [-0.0275,  0.0588,  0.1216,  ..., -0.1059, -0.1137, -0.1216],
          ...,
          [-0.0980, -0.1451, -0.1843,  ..., -0.1922, -0.1765, -0.1608],
          [-0.1686, -0.1686, -0.1059,  ..., -0.1922, -0.1059, -0.0667],
          [-0.2863, -0.2000, -0.0196,  ..., -0.0667,  0.0510, -0.0039]],

         [[-0.0902,  0.0275,  0.1216,  ...,  0.1059,  0.1294,  0.1373],
          [-0.0745,  0.0275,  

tensor([[[[-0.2392, -0.2471, -0.2471,  ...,  0.0196,  0.0275,  0.0667],
          [-0.1922, -0.2078, -0.2157,  ...,  0.0431,  0.0510,  0.0824],
          [-0.1451, -0.1608, -0.1765,  ...,  0.0510,  0.0353,  0.0431],
          ...,
          [-0.4510, -0.3647, -0.3333,  ..., -0.4118, -0.3020, -0.0588],
          [-0.1529, -0.1373, -0.1686,  ..., -0.3725, -0.2627, -0.1686],
          [-0.0118, -0.0431, -0.0667,  ..., -0.3804, -0.3098, -0.3412]],

         [[ 0.0431,  0.0353,  0.0275,  ...,  0.2235,  0.2314,  0.2706],
          [ 0.0902,  0.0745,  0.0667,  ...,  0.2471,  0.2549,  0.2863],
          [ 0.1216,  0.1059,  0.0902,  ...,  0.2314,  0.2157,  0.2235],
          ...,
          [-0.3098, -0.2235, -0.1765,  ..., -0.3255, -0.2157,  0.0275],
          [-0.0196, -0.0118, -0.0196,  ..., -0.2706, -0.1608, -0.0667],
          [ 0.1137,  0.1059,  0.0745,  ..., -0.2706, -0.2078, -0.2392]],

         [[ 0.2314,  0.2235,  0.2471,  ...,  0.1373,  0.1294,  0.1686],
          [ 0.2784,  0.2627,  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
