## This code is written in Pytorch for the DISK project (WITH Labels). 
@author: Sayantan
date: 12 October

In [60]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

import sys
import platform
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import glob
import os
import csv
import re ## For data manipulation

from PIL import Image, ImageOps

has_gpu = torch.cuda.is_available()
has_mps = torch.backends.mps.is_built()
device = "mps" if torch.backends.mps.is_built() \
    else "cuda:0" if torch.cuda.is_available() else "cpu"

print(f"Python Platform: {platform.platform()}")
print(f"PyTorch Version: {torch.__version__}")

# print(f"Scikit-Learn {sk.__version__}")
print("GPU is", "available" if has_gpu else "NOT AVAILABLE")
print("MPS (Apple Metal) is", "AVAILABLE" if has_mps else "NOT AVAILABLE")
print(f"Target device is {device}")



Python Platform: macOS-13.6-arm64-arm-64bit
PyTorch Version: 2.1.0.dev20230623
GPU is NOT AVAILABLE
MPS (Apple Metal) is AVAILABLE
Target device is mps


In [61]:
parser = argparse.ArgumentParser()
parser.add_argument('-f')
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=6, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=64, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

Namespace(f='/Users/auddy/Library/Jupyter/runtime/kernel-f97f9b9c-a01c-4a83-a19d-140392197a63.json', n_epochs=50, batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=8, latent_dim=100, n_classes=6, img_size=64, channels=1, sample_interval=100)


In [62]:
## Create a CSV file with Planet_Mass as labels and 
path = './'
Data_path = 'analysis_VAE_20JUlY2023/Disk_gas_plots/'
os.makedirs("analysis_data", exist_ok=True)
list_RT_path = glob.glob(path + Data_path + 'gas_gap*') ## make a list of all the RT folder where each folder is for each sim
list_sorted_RT_path  = sorted(list_RT_path, key=lambda f: [int(n) for n in re.findall(r"\d+", f)][-1]) ## sorting the images
df_images_folder =pd.DataFrame(list_sorted_RT_path,columns=["image_path"])
df_images_folder.to_csv(path+'./analysis_data/df_images_folder.csv')

# merging the csv with the simulation parameters files
joined_files = os.path.join("./analysis_VAE_20JUlY2023/", "Disk_gap_param*.csv")
  # A list of all joined files is returned
joined_list = glob.glob(joined_files)
# Finally, the files are joined
parameters_df = pd.concat(map(pd.read_csv, joined_list))
parameters_df['Planet_Mass'] = parameters_df['Planet_Mass'] / (3 * 10**-6) ## planet mass in Earth Mass units
# parameters_df
sorted_parmeters_df= parameters_df.sort_values('Sample#')
# # sorted_parmeters_df.to_csv(path+'./analysis_VAE_20JUlY2023/sorted_parmeters_df.csv')
complete_dataset = pd.concat([sorted_parmeters_df.reset_index(drop=True),df_images_folder],axis=1)
complete_dataset
complete_dataset.to_csv(path+'./analysis_data/complete_dataset.csv')
complete_dataset_mod = complete_dataset.drop(columns=['Sample#', 'SigmaSlope','Alpha'])

In [63]:
## Lets take a look at the data: As a first step we will form 
## 5 class by binning the planet masses in different range

# complete_dataset_mod
# complete_dataset_mod['Planet_Mass'].max()
# complete_dataset_mod['Planet_Mass'].min()

# Define the bin edges to create 5 classes
bin_edges = [0, 50, 100, 150, 200, 250,300]
labels = [1, 2, 3, 4, 5,6]
complete_dataset_mod['classes'] = pd.cut(complete_dataset_mod['Planet_Mass'], bins=bin_edges, labels=labels)
complete_dataset_mod

Unnamed: 0,Planet_Mass,image_path,classes
0,10.133333,./analysis_VAE_20JUlY2023/Disk_gas_plots/gas_g...,1
1,281.000000,./analysis_VAE_20JUlY2023/Disk_gas_plots/gas_g...,6
2,38.666667,./analysis_VAE_20JUlY2023/Disk_gas_plots/gas_g...,1
3,272.666667,./analysis_VAE_20JUlY2023/Disk_gas_plots/gas_g...,6
4,243.000000,./analysis_VAE_20JUlY2023/Disk_gas_plots/gas_g...,5
...,...,...,...
995,135.000000,./analysis_VAE_20JUlY2023/Disk_gas_plots/gas_g...,3
996,156.666667,./analysis_VAE_20JUlY2023/Disk_gas_plots/gas_g...,4
997,133.333333,./analysis_VAE_20JUlY2023/Disk_gas_plots/gas_g...,3
998,81.333333,./analysis_VAE_20JUlY2023/Disk_gas_plots/gas_g...,2


In [64]:
## Splitting the data for training and testing 
from sklearn.model_selection import train_test_split
split_1 = train_test_split(complete_dataset_mod, test_size=0.15, random_state=42)
(train, test) = split_1

In [6]:
# print(train['classes'].iloc[102])
# print(train['image_path'].iloc[102])
# print(train['Planet_Mass'].iloc[102])

In [65]:
## Defining a custom class that returns image tensor and the corresponding label for a given index
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self,dataframe,transform = None):
        self.df = dataframe
        self.transform = transform
    def __len__(self):
        return len(self.df)
    def __getitem__(self, index):
        filename = self.df["image_path"].iloc[index]
        label = self.df["classes"].iloc[index]
        image = Image.open(filename)
        image = image.convert('L') ## Converting to gray-scale with one channel
#         print("THE SHAPE OF THE IMAGE",np.shape(image))
        left = 105
        top = 45
        right = 560
        bottom = 500 
        # Cropped image of above dimension" # (It will not change original image)
        image = image.crop((left, top, right, bottom))
        if self.transform is not None:
            image = self.transform(image)
#         print(np.shape(image))
        return image,label

In [66]:
transform_custom = transforms.Compose([transforms.Resize(opt.img_size),
   transforms.ToTensor()])
# image_output = CustomDataset(train,transform= transform_custom)
image_output = CustomDataset(train,transform = None)
image,label = image_output.__getitem__(102)
print(np.shape(image),label)
# gray_image = ImageOps.grayscale(image)
# print(type(gray_image))
# image.show()
# print(np.shape(gray_image))
# # print(image)
# gray_image.show()
# image.shape()

(455, 455) 6


In [9]:
# Parameters
params = {'batch_size': 64,
          'shuffle': True}
#           'num_workers': 1}
# max_epochs = 2

In [67]:
# def get_mean_std(loader):
#     # Compute the mean and standard deviation of all pixels in the dataset
#     num_pixels = 0
#     mean = 0.0
#     std = 0.0
#     for images, _ in loader:
#         batch_size, num_channels, height, width = images.shape
#         num_pixels += batch_size * height * width
#         mean += images.mean(axis=(0, 2, 3)).sum()
#         std += images.std(axis=(0, 2, 3)).sum()

#     mean /= num_pixels
#     std /= num_pixels

#     return mean, std

In [68]:
training_set = CustomDataset(train,transform=transform_custom)
training_generator = torch.utils.data.DataLoader(training_set, **params)
mean, std = get_mean_std(training_generator)

## This time normalizing the images
# transform_custom = transforms.Compose([transforms.Resize(32),
#    transforms.ToTensor(),transforms.Normalize(mean=mean, std=std)
# ])

# training_set = CustomDataset(train,transform=transform_custom)
# training_generator = torch.utils.data.DataLoader(training_set, **params)

In [69]:
# opt.channels, opt.img_size, opt.img_size = 1 ,32,32
img_shape = (opt.channels, opt.img_size, opt.img_size)
print(img_shape)

(1, 64, 64)


In [70]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
         # Concatenate label embedding and image to produce input
        print(np.shape(self.label_emb(labels)),np.shape(noise))
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        print(np.shape(gen_input))
        img = self.model(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img

    
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)


        self.model = nn.Sequential(
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity

In [71]:
# Loss functions
adversarial_loss = torch.nn.MSELoss()
# adversarial_loss = torch.nn.BCELoss()
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# if cuda:
generator.to(device)
discriminator.to(device)
adversarial_loss.to(device)
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

In [73]:
def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(torch.FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))).to(device))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num+1 for _ in range(n_row) for num in range(n_row)])
    labels = Variable(torch.LongTensor(labels).to(device))
    print(labels)
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
os.makedirs("images", exist_ok=True)
sample_image(n_row=3, batches_done=20)

tensor([1, 2, 3, 1, 2, 3, 1, 2, 3], device='mps:0')
torch.Size([9, 6]) torch.Size([9, 100])
torch.Size([9, 106])


In [98]:
os.makedirs("images", exist_ok=True)
for epoch in range(opt.n_epochs):
    # Training
    for i, (imgs, labels) in enumerate(training_generator):
        print(type(imgs))
        batch_size = imgs.shape[0]
        print("Batch_size",batch_size)
        # Adversarial ground truths
        valid = Variable(torch.FloatTensor(batch_size, 1).fill_(1.0).to(device), requires_grad=False)
        fake = Variable(torch.FloatTensor(batch_size, 1).fill_(0.0).to(device), requires_grad=False)
        
        # Configure input
        real_imgs = Variable(imgs.type(torch.FloatTensor).to(device))
        labels = Variable(labels.type(torch.LongTensor).to(device))
#         print(real_imgs)
        
        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()
#         np.random.seed(1234)
        # Sample noise and labels as generator input
        z = Variable(torch.FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))).to(device))
        gen_labels = Variable(torch.LongTensor(np.random.randint(1, opt.n_classes, batch_size)).to(device))
        print(gen_labels)
        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)



        # Loss measures generator's ability to fool the discriminator
        validity = discriminator(gen_imgs, gen_labels)
        g_loss   = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()
        
        
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(training_generator), d_loss.item(), g_loss.item())
        )


        batches_done = epoch * len(training_generator) + i
#         print(epoch,batches_done,len(training_generator))
        print(np.shape(gen_imgs.data))
        if batches_done % opt.sample_interval == 0:
#             save_image(gen_imgs.data[:2], "images/%d.png" % batches_done, nrow=3, normalize=True)
            sample_image(n_row=5, batches_done=batches_done)


<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 0/50] [Batch 0/14] [D loss: 0.249540] [G loss: 0.287261]
torch.Size([64, 1, 64, 64])
tensor([1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4,
        5], device='mps:0')
torch.Size([25, 6]) torch.Size([25, 100])
torch.Size([25, 106])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 0/50] [Batch 1/14] [D loss: 0.244641] [G loss: 0.271095]
torc

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 1/50] [Batch 6/14] [D loss: 0.222548] [G loss: 0.245807]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 1/50] [Batch 7/14] [D loss: 0.237788] [G loss: 0.261317]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4,

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 2/50] [Batch 12/14] [D loss: 0.244055] [G loss: 0.663301]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 18
tensor([4, 2, 5, 4, 4, 3, 1, 1, 1, 5, 2, 1, 2, 3, 4, 5, 1, 2], device='mps:0')
torch.Size([18, 6]) torch.Size([18, 100])
torch.Size([18, 106])
[Epoch 2/50] [Batch 13/14] [D loss: 0.228496] [G loss: 0.255857]
torch.Size([18, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Ep

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 4/50] [Batch 5/14] [D loss: 0.222705] [G loss: 0.376306]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 4/50] [Batch 6/14] [D loss: 0.218236] [G loss: 0.392510]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4,

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 5/50] [Batch 11/14] [D loss: 0.259994] [G loss: 0.342101]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 5/50] [Batch 12/14] [D loss: 0.240455] [G loss: 0.356500]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 18
tensor([4, 2, 5, 4, 4, 3, 1, 1, 1, 5, 2, 1, 2, 3, 4, 5, 1, 2], device='mps:0')
torch.Size([18, 6]) torch.Size

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 7/50] [Batch 3/14] [D loss: 0.242814] [G loss: 0.302414]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 7/50] [Batch 4/14] [D loss: 0.237705] [G loss: 0.334449]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4,

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 8/50] [Batch 9/14] [D loss: 0.272341] [G loss: 0.226071]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 8/50] [Batch 10/14] [D loss: 0.263827] [G loss: 0.266885]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 10/50] [Batch 1/14] [D loss: 0.227306] [G loss: 0.327966]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 10/50] [Batch 2/14] [D loss: 0.254399] [G loss: 0.197950]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 11/50] [Batch 7/14] [D loss: 0.248892] [G loss: 0.259641]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 11/50] [Batch 8/14] [D loss: 0.241878] [G loss: 0.261534]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 13/50] [Batch 0/14] [D loss: 0.237537] [G loss: 0.249246]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 13/50] [Batch 1/14] [D loss: 0.226495] [G loss: 0.322000]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 14/50] [Batch 6/14] [D loss: 0.255269] [G loss: 0.223175]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 14/50] [Batch 7/14] [D loss: 0.254467] [G loss: 0.244492]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 15/50] [Batch 12/14] [D loss: 0.253997] [G loss: 0.261495]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 18
tensor([4, 2, 5, 4, 4, 3, 1, 1, 1, 5, 2, 1, 2, 3, 4, 5, 1, 2], device='mps:0')
torch.Size([18, 6]) torch.Size([18, 100])
torch.Size([18, 106])
[Epoch 15/50] [Batch 13/14] [D loss: 0.245658] [G loss: 0.269189]
torch.Size([18, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 17/50] [Batch 4/14] [D loss: 0.229322] [G loss: 0.346898]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 17/50] [Batch 5/14] [D loss: 0.226190] [G loss: 0.357710]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 18/50] [Batch 10/14] [D loss: 0.249920] [G loss: 0.269050]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 18/50] [Batch 11/14] [D loss: 0.258969] [G loss: 0.268094]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5,

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 20/50] [Batch 2/14] [D loss: 0.226581] [G loss: 0.298095]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 20/50] [Batch 3/14] [D loss: 0.241445] [G loss: 0.273162]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 21/50] [Batch 8/14] [D loss: 0.246249] [G loss: 0.312490]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 21/50] [Batch 9/14] [D loss: 0.239393] [G loss: 0.334697]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 23/50] [Batch 0/14] [D loss: 0.231224] [G loss: 0.282843]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 23/50] [Batch 1/14] [D loss: 0.224552] [G loss: 0.306646]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 24/50] [Batch 6/14] [D loss: 0.191804] [G loss: 0.271404]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 24/50] [Batch 7/14] [D loss: 0.207570] [G loss: 0.422066]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 25/50] [Batch 12/14] [D loss: 0.258125] [G loss: 0.220426]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 18
tensor([4, 2, 5, 4, 4, 3, 1, 1, 1, 5, 2, 1, 2, 3, 4, 5, 1, 2], device='mps:0')
torch.Size([18, 6]) torch.Size([18, 100])
torch.Size([18, 106])
[Epoch 25/50] [Batch 13/14] [D loss: 0.257861] [G loss: 0.230057]
torch.Size([18, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Si

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 27/50] [Batch 4/14] [D loss: 0.343855] [G loss: 0.658112]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 27/50] [Batch 5/14] [D loss: 0.452065] [G loss: 0.008528]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 28/50] [Batch 10/14] [D loss: 0.256154] [G loss: 0.263868]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 28/50] [Batch 11/14] [D loss: 0.246795] [G loss: 0.276254]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 30/50] [Batch 2/14] [D loss: 0.211742] [G loss: 0.446967]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 30/50] [Batch 3/14] [D loss: 0.259973] [G loss: 0.124558]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 31/50] [Batch 8/14] [D loss: 0.255258] [G loss: 0.254874]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 31/50] [Batch 9/14] [D loss: 0.260193] [G loss: 0.251391]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 33/50] [Batch 0/14] [D loss: 0.239578] [G loss: 0.287587]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 33/50] [Batch 1/14] [D loss: 0.233403] [G loss: 0.271724]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 34/50] [Batch 6/14] [D loss: 0.265314] [G loss: 0.239376]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 34/50] [Batch 7/14] [D loss: 0.265248] [G loss: 0.278882]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 35/50] [Batch 12/14] [D loss: 0.232054] [G loss: 0.326521]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 18
tensor([4, 2, 5, 4, 4, 3, 1, 1, 1, 5, 2, 1, 2, 3, 4, 5, 1, 2], device='mps:0')
torch.Size([18, 6]) torch.Size([18, 100])
torch.Size([18, 106])
[Epoch 35/50] [Batch 13/14] [D loss: 0.209213] [G loss: 0.310793]
torch.Size([18, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Si

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 37/50] [Batch 4/14] [D loss: 0.225251] [G loss: 0.317455]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 37/50] [Batch 5/14] [D loss: 0.240127] [G loss: 0.284471]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 38/50] [Batch 10/14] [D loss: 0.279216] [G loss: 0.179556]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 38/50] [Batch 11/14] [D loss: 0.263029] [G loss: 0.231850]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5,

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 40/50] [Batch 2/14] [D loss: 0.235436] [G loss: 0.627200]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 40/50] [Batch 3/14] [D loss: 0.325667] [G loss: 0.067439]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 41/50] [Batch 8/14] [D loss: 0.271574] [G loss: 0.268202]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 41/50] [Batch 9/14] [D loss: 0.257069] [G loss: 0.197070]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 43/50] [Batch 0/14] [D loss: 0.242322] [G loss: 0.272023]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 43/50] [Batch 1/14] [D loss: 0.255034] [G loss: 0.266531]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 44/50] [Batch 6/14] [D loss: 0.240403] [G loss: 0.355667]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 44/50] [Batch 7/14] [D loss: 0.234017] [G loss: 0.234002]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 45/50] [Batch 12/14] [D loss: 0.242826] [G loss: 0.297715]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 18
tensor([4, 2, 5, 4, 4, 3, 1, 1, 1, 5, 2, 1, 2, 3, 4, 5, 1, 2], device='mps:0')
torch.Size([18, 6]) torch.Size([18, 100])
torch.Size([18, 106])
[Epoch 45/50] [Batch 13/14] [D loss: 0.235695] [G loss: 0.295681]
torch.Size([18, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 47/50] [Batch 4/14] [D loss: 0.244680] [G loss: 0.245529]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 47/50] [Batch 5/14] [D loss: 0.242388] [G loss: 0.257686]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 

<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 48/50] [Batch 10/14] [D loss: 0.424455] [G loss: 0.020555]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5, 4, 5, 1, 5, 4, 2, 3, 1, 3, 2, 4, 5, 4, 1, 2, 2, 1,
        4, 3, 1, 3, 5, 3, 2, 3, 3, 3, 5, 5, 1, 4, 1, 2], device='mps:0')
torch.Size([64, 6]) torch.Size([64, 100])
torch.Size([64, 106])
[Epoch 48/50] [Batch 11/14] [D loss: 0.296062] [G loss: 0.088488]
torch.Size([64, 1, 64, 64])
<class 'torch.Tensor'>
Batch_size 64
tensor([5, 3, 1, 1, 4, 3, 1, 4, 4, 2, 5, 2, 4, 1, 4, 4, 1, 5, 4, 5, 4, 4, 1, 2,
        2, 5, 4, 4, 4, 5, 5

In [17]:
# def sample_image(n_row, batches_done):
#     """Saves a grid of generated digits ranging from 0 to n_classes"""
#     # Sample noise
#     z = Variable(torch.FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))).to(device))
#     # Get labels ranging from 0 to n_classes for n rows
# #     labels = np.array([num for _ in range(n_row) for num in range(n_row)])
# #     labels = Variable(torch.LongTensor(labels).to(device))
#     gen_imgs = generator(z)
#     print(np.shape(gen_imgs))
#     save_image(gen_imgs.data, "images/test%d.png" % batches_done, nrow=n_row, normalize=True)
# os.makedirs("images", exist_ok=True)
# sample_image(n_row=3, batches_done=20)

In [81]:
batch_size = 2
z = Variable(torch.FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))).to(device))
gen_labels = Variable(torch.LongTensor(np.random.randint(1, opt.n_classes, batch_size)).to(device))
# gen_labels = Variable(torch.LongTensor(np.arrange(1, opt.n_classes, batch_size)).to(device))
# Generate a batch of images
gen_imgs = generator(z, gen_labels)

torch.Size([2, 6]) torch.Size([2, 100])
torch.Size([2, 106])


In [48]:
np.shape(z)
# save_image(gen_imgs.data, "images/test%d.png" % batches_done, nrow=n_row, normalize=True)

torch.Size([1, 100])

In [82]:
np.shape(gen_labels)

torch.Size([2])

In [80]:
gen_labels

tensor([4, 2], device='mps:0')

In [33]:
np.random.seed(1234)
z = Variable(torch.FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))).to(device))

In [36]:
np.shape(z)

torch.Size([2, 100])

In [84]:
test_class = np.asarray([1,2,3])

In [85]:
test_class

array([1, 2, 3])

In [100]:
batch_size = 3
np.random.seed(1234)
z = Variable(torch.FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))).to(device))
gen_labels = Variable(torch.LongTensor(test_class).to(device))
# gen_labels = Variable(torch.LongTensor(np.arrange(1, opt.n_classes, batch_size)).to(device))
# Generate a batch of images
gen_imgs = generator(z, gen_labels)
save_image(gen_imgs.data, "images/test1%d.png" % batches_done, nrow=n_row, normalize=True)

torch.Size([3, 6]) torch.Size([3, 100])
torch.Size([3, 106])
