<a href="https://colab.research.google.com/github/tansyab1/MBCAESRWC/blob/main/MWCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Connect **Google Colab** to **Google Drive**

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Report the installed PyTorch version.
import torch
print(torch.__version__)

1.7.0+cu101


We will verify that GPU is enabled for this notebook

The following cell should print: ***CUDA is available!  Training on GPU ...***
 
If it prints otherwise, then you need to enable GPU: 

From **Menu** > **Runtime** > **Change Runtime Type** > **Hardware Accelerator** > **GPU**

In [None]:
import torch
import numpy as np
import time

# Ask PyTorch whether a CUDA device is usable and report which device training will use.
train_on_gpu = torch.cuda.is_available()

print('CUDA is available!  Training on GPU ...' if train_on_gpu
      else 'CUDA is not available.  Training on CPU ...')

CUDA is available!  Training on GPU ...


In [None]:
# Shell out to nvidia-smi to show the attached GPU, driver/CUDA versions and memory usage.
!nvidia-smi

Mon Jan 25 07:12:00 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   60C    P8    11W /  70W |     10MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:

from torch.autograd import Variable
from torchvision import datasets
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torch.distributions import Categorical
import torchvision.transforms as transforms
import argparse
import os
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as io
import math
import numbers
import pywt
from tqdm import tqdm
from matplotlib import pyplot
from ipywidgets import interact
from PIL import Image
from scipy.stats import entropy
from collections import OrderedDict
from skimage.feature import greycomatrix, greycoprops
# ! git clone https://github.com/fbcotter/pytorch_wavelets
# ! cd ./pytorch_wavelets/
# ! pip install ./pytorch_wavelets/
# from pytorch_wavelets import DWTForward, DWTInverse # (or import DWT, IDWT)



## 1. Define the **loss function**, which sums three terms: reconstruction loss, energy of the `I_H` branch, and classification loss



In [None]:
def loss_fn( I_H, I_train, I_rec, score , y_train):
    """Return the total training loss as the unweighted sum of three terms.

    Args:
        I_H: output of the ``decoderH`` branch; its 2-norm is added as an
            energy penalty.
        I_train: input images, used as the reconstruction target.
        I_rec: reconstructed images, compared with ``I_train`` by MSE.
        score: class scores predicted by the model.
        y_train: target the scores are regressed onto with MSE (the training
            cell passes a one-hot float tensor).

    Returns:
        A scalar tensor: ``mse(I_rec, I_train) + ||I_H||_2 + mse(score, y_train)``.
    """
    terms = (
        F.mse_loss(I_rec, I_train),   # reconstruction
        torch.norm(I_H, 2),           # energy of the I_H branch
        F.mse_loss(score, y_train),   # classification (MSE against the target)
    )
    return terms[0] + terms[1] + terms[2]

## 2. Define the **model**
    Arguments:
        split_ratio (float): ratio of training size to total dataset size
        batch_size (int): number of samples used to derive the train/test sizes stored on the model
        num_classes (int): number of classes in the dataset
        negative_slope (float, optional): negative slope of the LeakyReLU activations

In [None]:


class VAEGT(nn.Module):
    """Two-branch convolutional auto-encoder with a joint classifier.

    Pipeline implemented by :meth:`forward`:

    1. ``encoder`` maps the 1-channel input to 16 feature maps (same spatial size).
    2. ``decoderL`` and ``decoderH`` each map the features back to a 1-channel
       image (``I_L`` and ``I_H``); their sum is the reconstruction ``I_rec``.
       (Presumably "L"/"H" stand for low/high frequency components -- confirm.)
    3. ``VGG`` extracts features from ``I_L`` and ``fusion`` (a narrower stack
       with the same layout) extracts features from ``I_H``.
    4. ``classifer`` scores the ``I_L`` features alone, ``classifer_final``
       scores the concatenated features; the two score vectors are added.

    The linear heads expect 3072 (``VGG``) and 768 (``fusion``) flattened
    features, which is what five 2x2 poolings produce for 96x84 inputs
    (512*3*2 and 128*3*2).

    Args:
        split_ratio: fraction used to derive ``train_size`` from ``batch_size``.
        batch_size: sample count used only to compute ``train_size``/``test_size``.
        num_classes: size of the score vector returned by the heads.
        negative_slope: negative slope of every LeakyReLU in the network.
    """

    def __init__(self, split_ratio, batch_size=1000, num_classes=10, negative_slope=0.1):
        super(VAEGT, self).__init__()
        self.batch_size = batch_size
        self.split_ratio = split_ratio
        # Multipliers applied to the flattened feature sizes of the two heads.
        # NOTE(review): the original trailing comments read "#5" and "#4";
        # their meaning is not visible here -- confirm before changing.
        self.flag1 = 1 #5
        self.flag2 = 1  #4

        self.train_size = int(self.split_ratio*self.batch_size)
        self.test_size = batch_size - self.train_size

        self.num_classes = num_classes
        self.negative_slope = negative_slope

        # Sub-modules are created in a fixed order (encoder, decoderL, decoderH,
        # VGG, classifer, fusion, classifer_final) so parameter names and
        # state_dict ordering stay stable.
        self.encoder = self._make_encoder(negative_slope)

        self.decoderL = self._make_decoder(negative_slope)
        self.decoderH = self._make_decoder(negative_slope)

        self.VGG = self._make_vgg_stack(
            1, [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], negative_slope)

        self.classifer = self._make_head(
            3072*self.flag1, 4096, self.num_classes, negative_slope)

        self.fusion = self._make_vgg_stack(
            1, [(16, 2), (32, 2), (64, 3), (128, 3), (128, 3)], negative_slope)

        self.classifer_final = self._make_head(
            3072*self.flag1+768*self.flag2, 2048, self.num_classes, negative_slope)

        self._init_weights()

    @staticmethod
    def _make_encoder(negative_slope):
        """Build three 3x3 conv + BatchNorm + LeakyReLU stages (1 -> 16 -> 16 -> 16)."""
        layers = []
        channels = 1
        for idx in (1, 2, 3):
            layers.append(('layer%d' % idx, nn.Conv2d(in_channels=channels, out_channels=16,
                                                      kernel_size=3, stride=1, padding=1)))
            layers.append(('bat%d' % idx, nn.BatchNorm2d(16)))
            layers.append(('relu%d' % idx, nn.LeakyReLU(negative_slope=negative_slope, inplace=True)))
            channels = 16
        return nn.Sequential(OrderedDict(layers))

    @staticmethod
    def _make_decoder(negative_slope):
        """Build one decoder branch: stride-2 conv, transposed conv, then three 3x3 convs."""
        def act():
            return nn.LeakyReLU(negative_slope=negative_slope, inplace=True)

        return nn.Sequential(OrderedDict([
            ('layer0', nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, stride=2)),
            ('bat1', nn.BatchNorm2d(1)),
            ('relu0', act()),
            ('layer1', nn.ConvTranspose2d(in_channels=1, out_channels=16, kernel_size=4, stride=2)),
            ('bat2', nn.BatchNorm2d(16)),
            ('relu1', act()),
            ('layer2', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)),
            ('bat3', nn.BatchNorm2d(16)),
            ('relu2', act()),
            ('layer3', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)),
            ('bat4', nn.BatchNorm2d(16)),
            ('relu3', act()),
            ('layer4', nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, stride=1, padding=1)),
            ('bat5', nn.BatchNorm2d(1)),
            ('relu4', act()),
        ]))

    @staticmethod
    def _make_vgg_stack(in_channels, plan, negative_slope):
        """Build a VGG-style feature extractor.

        ``plan`` is a list of ``(out_channels, num_convs)`` pairs, one per
        block. Each conv is followed by BatchNorm and LeakyReLU and every
        block ends with a 2x2 max-pool. Layer names follow the pattern
        ``layer<block><conv>``, ``bat<running index>``, ``relu<block><conv>``
        and ``maxpool<block>``.
        """
        layers = []
        channels = in_channels
        bn_index = 0
        for block, (width, depth) in enumerate(plan):
            for conv in range(1, depth + 1):
                bn_index += 1
                tag = '%d%d' % (block, conv)
                layers.append(('layer' + tag, nn.Conv2d(in_channels=channels, out_channels=width,
                                                        kernel_size=3, padding=1)))
                layers.append(('bat%d' % bn_index, nn.BatchNorm2d(width)))
                layers.append(('relu' + tag, nn.LeakyReLU(negative_slope=negative_slope, inplace=True)))
                channels = width
            layers.append(('maxpool%d' % block, nn.MaxPool2d(kernel_size=2, stride=2)))
        return nn.Sequential(OrderedDict(layers))

    @staticmethod
    def _make_head(in_features, hidden, num_classes, negative_slope):
        """Build a three-layer fully connected head with dropout (p=0.65) between layers."""
        return nn.Sequential(OrderedDict([
            ('linear0', nn.Linear(in_features=in_features, out_features=hidden)),
            ('relu0', nn.LeakyReLU(negative_slope=negative_slope, inplace=True)),
            ('dropout0', nn.Dropout(p=0.65)),
            ('linear1', nn.Linear(in_features=hidden, out_features=hidden)),
            ('relu1', nn.LeakyReLU(negative_slope=negative_slope, inplace=True)),
            ('dropout1', nn.Dropout(p=0.65)),
            ('linear2', nn.Linear(in_features=hidden, out_features=num_classes)),
        ]))

    def forward(self, I_train, y_train=None):
        """Reconstruct the input and score its class.

        Runs in both training and evaluation mode (BatchNorm and Dropout
        switch behaviour on their own), and on whatever device the module and
        its input live on.

        Args:
            I_train: image batch with one channel; the linear heads require
                the spatial size to yield 3072 / 768 flattened features.
            y_train: accepted for call compatibility; not used.

        Returns:
            Tuple ``(I_rec, I_H, score_final)``: the reconstruction
            ``I_L + I_H``, the ``decoderH`` output, and the summed class scores.
        """
        # Encode once, decode through both branches.
        x = self.encoder(I_train)
        I_L = self.decoderL(x)
        I_H = self.decoderH(x)
        I_rec = I_L + I_H

        # Per-branch feature extraction, flattened to (batch, features).
        x_L = self.VGG(I_L)
        x_H = self.fusion(I_H)
        x_L = x_L.view(x_L.size(0), -1)
        x_H = x_H.view(x_H.size(0), -1)

        # Concatenate before the first head runs: its in-place activations
        # act on the head's own outputs, never on x_L itself.
        x = torch.cat([x_L, x_H], dim=1)
        score_L = self.classifer(x_L)
        score_H = self.classifer_final(x)

        score_final = score_L + score_H

        return I_rec, I_H, score_final

    def _init_weights(self):
        """Initialise Linear layers with N(0, 0.01) weights / zero bias and BatchNorm to identity.

        Convolution layers keep PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

## 3. Define the **improve checker** to save the best checkpoint

In [None]:
# Accuracy checker: mode "min" for loss, mode "max" for accuracy
class ImproveChecker():
    """Track the best value of a metric and report whether a new value beats it.

    Args:
        mode: ``'min'`` when lower is better (e.g. a loss), ``'max'`` when
            higher is better (e.g. an accuracy).
        best_val: optional starting best value. When omitted the checker
            starts from +inf (``'min'``) or -inf (``'max'``) so that the
            first finite value always counts as an improvement, including
            zero or negative metrics in ``'max'`` mode.
    """

    def __init__(self, mode='min', best_val=None):
        assert mode in ['min', 'max']
        self.mode = mode
        if best_val is not None:
            self.best_val = best_val
        elif self.mode == 'min':
            self.best_val = float('inf')
        else:
            self.best_val = float('-inf')

    def _check(self, val):
        """Compare ``val`` with the stored best, update it on improvement.

        Prints a one-line status message either way.

        Returns:
            ``True`` if ``val`` is strictly better than the previous best
            (which is then replaced by ``val``), otherwise ``False``.
        """
        if self.mode == 'min':
            improved = val < self.best_val
        else:
            improved = val > self.best_val

        if improved:
            print("[%s] Improved from %.4f to %.4f" % (self.__class__.__name__, self.best_val, val))
            self.best_val = val
            return True

        print("[%s] Not improved from %.4f" % (self.__class__.__name__, self.best_val))
        return False

## 4. Process the **data**
    Arguments:
        mat_path (string): path to the `.mat` dataset file
        split_ratio (float): ratio of training size to total dataset size
        width, height (int): image shape

In [None]:
 # Process data
 class MyDataset(Dataset):
     """In-memory image dataset read from a MATLAB ``.mat`` file.

     The file must contain a ``features`` array and a ``Label`` array. The
     transposed features are reshaped to ``(N, width, height)`` and given a
     leading channel axis, so each item is a ``(1, width, height)`` float
     tensor paired with an integer (``torch.long``) label.

     Args:
         mat_path: path of the ``.mat`` file.
         split_ratio: fraction of the samples assigned to the training split.
         width, height: the two spatial sizes used, in that order, when
             reshaping each flattened image.
     """

     def __init__(self, mat_path,split_ratio,width,height):
         self.width = width
         self.height = height
         self.split_ratio = split_ratio

         # Parse the file once and pull both arrays from the same dict.
         mat = io.loadmat(mat_path, squeeze_me=True)
         feature = mat['features']
         Label = mat['Label']

         # Transposing puts the sample axis first before reshaping to images.
         self.images = torch.from_numpy(np.transpose(feature)).type(torch.float)
         self.images = torch.reshape(self.images, [len(self.images), self.width, self.height])
         self.images = torch.unsqueeze(self.images, dim=1)  # add the channel axis

         self.target = torch.from_numpy(np.transpose(Label)).type(torch.long)
         self.data = list(zip(self.images, self.target))

         self.train_size = int(self.split_ratio*len(self.data))
         self.test_size = len(self.data) - self.train_size

     def _generate(self):
         """Randomly split the (image, label) pairs and return ``(train, test)`` subsets."""
         self.train_dataset, self.test_dataset = torch.utils.data.random_split(
             self.data, [self.train_size, self.test_size])
         return self.train_dataset, self.test_dataset

     def __getitem__(self, index):
         """Return the ``(image, label)`` pair stored at ``index``."""
         images = self.images[index]
         target = self.target[index]
         return images, target

     def __len__(self):
         """Return the number of samples."""
         return len(self.data)



### 4.1. Declare the dataset path and the arguments used to configure the **dataset**


In [None]:
# Dataset location and basic hyper-parameters.
mat_path = ('/content/drive/My Drive/Colab Notebooks/Master Internship/dataset/YaleB/YaleB_96x84.mat')
split_ratio = 0.9
# NOTE(review): the model-initialisation cell passes num_classes=50 rather than
# this value -- confirm which one is intended.
num_classes = 40
batch_size = 100
size= [96,84] #size of input image

# Load everything into memory and draw one random train/test split.
custom_data = MyDataset(mat_path,split_ratio=split_ratio, width = size[0], height=size[1])
data_train, data_test = custom_data._generate()

# Reshuffle the training samples every epoch; the one-off random split alone
# would present the batches in the same order each time.
dataloader_train = torch.utils.data.DataLoader(dataset=data_train,
                                           num_workers=4,
                                           batch_size= batch_size,
                                           shuffle=True,
                                           pin_memory=True)

# Evaluation order does not matter, so the test loader stays sequential.
dataloader_test = torch.utils.data.DataLoader(dataset=data_test,
                                           num_workers=4,
                                           batch_size= batch_size,
                                           pin_memory=True)


## 5. Init the **model**, **optimizer** and **improvechecker**




In [None]:
# Build the network and move its parameters to the GPU.
# NOTE(review): num_classes is hard-coded to 50 here while the dataset cell
# defines num_classes = 40 -- confirm which is intended.
model = VAEGT(split_ratio=split_ratio, batch_size=len(custom_data), num_classes=50)
model.cuda()

# Adam over all model parameters.
# NOTE(review): lr=1e-1 is two orders of magnitude above Adam's default of
# 1e-3 -- confirm this is deliberate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)

# Tracks the lowest loss seen so far (lower is better).
improvechecker = ImproveChecker(mode='min')

# Number of training epochs.
num_epoch = 1000

In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameters and return their total count.

    Parameters with ``requires_grad == False`` are skipped.

    Args:
        model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).

    Returns:
        Total number of trainable scalar parameters.
    """
    table = PrettyTable(["Modules", "Parameters"])
    trainable = [(name, tensor.numel())
                 for name, tensor in model.named_parameters()
                 if tensor.requires_grad]
    for name, count in trainable:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

count_parameters(model)

+--------------------------------+------------+
|            Modules             | Parameters |
+--------------------------------+------------+
|     encoder.layer1.weight      |    144     |
|      encoder.layer1.bias       |     16     |
|      encoder.bat1.weight       |     16     |
|       encoder.bat1.bias        |     16     |
|     encoder.layer2.weight      |    2304    |
|      encoder.layer2.bias       |     16     |
|      encoder.bat2.weight       |     16     |
|       encoder.bat2.bias        |     16     |
|     encoder.layer3.weight      |    2304    |
|      encoder.layer3.bias       |     16     |
|      encoder.bat3.weight       |     16     |
|       encoder.bat3.bias        |     16     |
|     decoderL.layer0.weight     |    144     |
|      decoderL.layer0.bias      |     1      |
|      decoderL.bat1.weight      |     1      |
|       decoderL.bat1.bias       |     1      |
|     decoderL.layer1.weight     |    256     |
|      decoderL.layer1.bias      |     1

57398432

## 6. Start training

In [None]:
# Training process

model.train()
# Follow whatever device the model lives on instead of hard-coding CUDA.
device = next(model.parameters()).device

# Run exactly num_epoch epochs, numbered 1..num_epoch in the log.
for epoch in range(1, num_epoch + 1):
    correct_pred = 0      # correctly classified samples this epoch
    seen_samples = 0      # samples processed this epoch (last batch may be short)
    running_loss = 0.0
    num_batches = 0
    start = time.perf_counter()
    # Training
    for inputs_train, labels_train in dataloader_train: # Dataloader for training set
      inputs_train = inputs_train.to(device)
      labels_train = labels_train.view(-1).to(device)
      # One-hot float targets sized to the model's output width.
      y_onehot_train = torch.nn.functional.one_hot(
          labels_train, num_classes=model.num_classes).float()

      optimizer.zero_grad()
      I_rec, I_H, score_final = model(inputs_train, y_onehot_train)

      pred = score_final.argmax(dim=1)
      correct_pred += (pred == labels_train).sum().item()
      seen_samples += labels_train.size(0)

      loss = loss_fn( I_H, inputs_train, I_rec, score_final , y_onehot_train)
      loss.backward()
      optimizer.step()

      running_loss += float(loss.item())
      num_batches += 1
    end = time.perf_counter()
    # max(..., 1) keeps the report well-defined if the loader yields nothing.
    print('epoch {}/{}\tTrain loss: {:.4f}\tTrain accuracy: {:.2f}%'.
        format(epoch, num_epoch, running_loss / max(num_batches, 1), correct_pred / max(seen_samples, 1) * 100))
    print('Time: {:.2f}s'.format(end - start))




tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tens

KeyboardInterrupt: ignored