In [1]:
import torch
from torch import optim
from torch import Tensor
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as tt
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data import random_split

from utils import *
import config
import random

from typing import Type

from Network_model import Generator, ConvModel

from tqdm import tqdm, trange
import copy 

from Network_model import Generator, Discriminator

In [2]:
print_config()

RANDOM_SEED   :  11042004
DATA_DIR      :    ./data
USED_DATA     :     MNIST
NUM_LABELLED  :       100
DEVICE        :    cuda:0
EPOCHS        :        20
BATCH_SIZE    :        32
LEARNING_RATE :    0.0002
SCHED         :     False
GAN_BATCH_SIZE:       128


In [3]:
set_random_seed(config.RANDOM_SEED)
random.seed(config.RANDOM_SEED)

Setting seeds ...... 



In [4]:
name = "MarginGAN"

In [5]:
if config.USED_DATA == "CIFAR10":
	mean = [0.5]*3
	std = [0.5]*3

	train_tfm = tt.Compose([
		tt.Resize(32),
		tt.RandomCrop(32, padding=4, padding_mode='edge'),
		tt.RandomHorizontalFlip(),
		tt.Normalize(mean, std, inplace=True)
	])

if config.USED_DATA == "MNIST":
	mean = [0.5]
	std = [0.5]
	train_tfm = tt.Compose([
		tt.Resize(32),
		tt.Normalize(mean, std, inplace=True)
	])

test_tfm = tt.Compose([
	tt.Resize(32),
	tt.Normalize(mean, std)
])

In [6]:
train_ds, test_ds, classes = load_data(train_tfm, test_tfm)

In [7]:
X_full = train_ds.x
y_full = train_ds.y

In [8]:
print(X_full[0])

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

In [9]:
classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

In [10]:
n_classes = len(classes)
channels = train_ds[0][0].shape[0] # MNIST
n_classes, channels

(10, 1)

In [11]:
X_sup, y_sup, X_unsup, _ = supervised_samples(X_full, y_full, config.NUM_LABELLED, n_classes, get_unsup=True)

In [12]:
test_dl = CreateDataLoader(test_ds, batch_size=512, transform=test_tfm, device=config.DEVICE)

In [13]:
class Classifier(nn.Module):
    def __init__(self, in_channels, n_classes) -> None:
        super().__init__()

        self.Conv = ConvModel(in_channels)

        self.dropout = nn.Dropout(0.5)

        self.out = nn.Linear(512, n_classes)
        
    def forward(self, X: Tensor):
        out = self.Conv(X)
        out = self.dropout(out)
        out = self.out(out) 
        return out

In [14]:
class ICE(nn.Module):
	def __init__(self):
		super().__init__()

	
	def forward(self, preds, labels):

		outs = torch.softmax(preds, dim = 1)[:, labels]

		outs = -torch.log(1 - outs)
		return outs.mean()


In [15]:
class MarginGAN: 
    def __init__(self, latent_size, n_channels, n_classes, device): 
        self.latent_size = latent_size
        self.n_classes = n_classes
        self.generator = Generator(latent_size, n_channels)
        self.classifier = Classifier(n_channels, n_classes)

        self.CEloss = nn.CrossEntropyLoss() 
      
        self.ice = ICE()
        self.device = device 
        self.to(device)
    
    def to(self, device): 
        self.generator.to(device) 
        self.classifier.to(device)
    
    def load_gen_state_dict(self, file):
        self.generator.load_state_dict(torch.load(file))
    
    def accuracy(self, test_dl): 
        corrected = 0
        for b in tqdm(test_dl):
            images, y = b
            outs = self.classifier.forward(images)
            outs = torch.argmax(outs[:, :-1], dim=1)
            corrected += (outs == y).sum().item()
        return corrected / test_dl.num_data()
    
    
    def classifier_step(self, sup_imgs, sup_labels, unsup_imgs, batch_size): 
        # Loss for labeled samples
        sup_outs = self.classifier(sup_imgs) 
        sup_loss = self.CEloss(sup_outs, sup_labels)

        # Loss for unlabeled samples
        # Pseudo_label:  Pick up the class which
        # has maximum predicted probability for each unlabeled
        # sample
        unsup_outs = self.classifier(unsup_imgs) 
        unsup_pseudolabels = torch.argmax(unsup_outs, dim = 1) 
        # print(unsup_pseudolabels.shape)
        unsup_loss = self.CEloss(unsup_outs, unsup_pseudolabels)

        # Loss for generated samples. Also pseudo_labelling as for 
        # unsup imgs, but now apply the inverted binary cross entropy 
        # as loss. Aim: decrease the margin of these data points
        # and make the prediction distribution flat
        z = torch.randn([batch_size, self.latent_size, 1, 1], device = self.device)
        fake_imgs = self.generator(z) 
        fake_outs = self.classifier(fake_imgs)
        fake_pseudolabels = torch.argmax(fake_outs, dim = 1) 
        fake_loss = self.ice(fake_outs, fake_pseudolabels) 

        return sup_loss + unsup_loss + fake_loss

    def fit(self, epochs, batch_size, batch_per_epoch, dis_lr, max_lr, sup_ds:CustomDataSet, unsup_ds:CustomDataSet, full_ds:CustomDataSet, test_dl, optim:Type[optim.Optimizer], weight_decay = 0, sched = True, PATH = ".", save_best = False, grad_clip = False): 
        # optimizerD = optim(self.discriminator.parameters(), lr = dis_lr)
        optimizerC = optim(self.classifier.parameters(), lr = max_lr, weight_decay = weight_decay)

        if sched: 
            OneCycleLR = torch.optim.lr_scheduler.OneCycleLR(optimizerC, max_lr, epochs=epochs, steps_per_epoch=batch_per_epoch)
        with open('acc.txt', 'w') as f:
            # self.discriminator.train()
            self.classifier.train() 
            best_val = 0 
            for epoch in (range(epochs)):
                for i in (trange(batch_per_epoch)): 
                    sup_imgs, labels = random_split(sup_ds, [batch_size, len(sup_ds) - batch_size])[0][:]
                    full_imgs = random_split(full_ds, [batch_size, len(full_ds) - batch_size])[0][:]
                    unsup_imgs = random_split(unsup_ds, [batch_size, len(unsup_ds) - batch_size])[0][:]

                    # train classifier
                    C_loss = self.classifier_step(sup_imgs.to(self.device), labels.to(self.device), unsup_imgs.to(self.device), batch_size) 
                    C_loss.backward()
                    
                    # if grad_clip: 
                    #     torch.nn.utils.clip_grad_value_(self.classifier.parameters(), 0.1)
                    
                    optimizerC.step()
                    optimizerC.zero_grad()
                    if sched: 
                        OneCycleLR.step()
                    
                    tqdm.write(f'C_loss: {C_loss.detach().item()}', end = "\r")
                    
                # self.discriminator.eval()
                self.classifier.eval()
                acc = self.accuracy(test_dl)
                if acc > best_val: 
                    best_val = acc 
                    if save_best: 
                        torch.save(self.classifier.state_dict(), PATH)
                f.write(f'accuracy: {best_val}\n')
                tqdm.write(f'accuracy: {best_val}', end = "\r")

In [16]:
mGAN = MarginGAN(100, channels, n_classes, config.DEVICE)

In [17]:
mGAN.classifier

Classifier(
  (Conv): ConvModel(
    (initial): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (Conv): Sequential(
      (0): ConvBn(
        (Conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (1): ConvBn(
        (Conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU(inplace=True)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (2): ConvBn(
        (Conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Bn): BatchNorm2d(512, eps=1e-05

In [18]:
i = 4
mGAN.load_gen_state_dict(f"DCGAN/{config.USED_DATA}/netG_epoch_00{i}.pth")

In [19]:
sup_ds = CustomDataSet(X_sup, y_sup, train_tfm)
full_ds = CustomDataSet(X_full, None, train_tfm)
unsup_ds = CustomDataSet(X_unsup, None, train_tfm)

In [20]:
mGAN.fit(10, 64, 50, 1e-5, 1e-4, sup_ds, unsup_ds, full_ds, test_dl, optim.RMSprop, save_best = False, grad_clip=False)

  4%|▍         | 2/50 [00:00<00:20,  2.34it/s]

C_loss: 4.3105902671813965

  8%|▊         | 4/50 [00:01<00:11,  3.94it/s]

C_loss: 3.679995059967041

 12%|█▏        | 6/50 [00:01<00:08,  4.90it/s]

C_loss: 3.057500123977661

 16%|█▌        | 8/50 [00:01<00:07,  5.36it/s]

C_loss: 3.1954281330108643

 20%|██        | 10/50 [00:02<00:07,  5.65it/s]

C_loss: 2.5958189964294434

 24%|██▍       | 12/50 [00:02<00:06,  5.77it/s]

C_loss: 2.7134156227111816

 28%|██▊       | 14/50 [00:02<00:06,  5.84it/s]

C_loss: 2.3355040550231934

 32%|███▏      | 16/50 [00:03<00:05,  5.86it/s]

C_loss: 2.300259828567505

 36%|███▌      | 18/50 [00:03<00:05,  5.91it/s]

C_loss: 2.1485767364501953

 40%|████      | 20/50 [00:04<00:05,  5.94it/s]

C_loss: 2.0873754024505615

 44%|████▍     | 22/50 [00:04<00:04,  5.93it/s]

C_loss: 1.8133827447891235

 48%|████▊     | 24/50 [00:04<00:04,  5.92it/s]

C_loss: 1.7688980102539062

 52%|█████▏    | 26/50 [00:05<00:04,  5.93it/s]

C_loss: 1.4029350280761719

 56%|█████▌    | 28/50 [00:05<00:03,  5.94it/s]

C_loss: 1.4122278690338135

 60%|██████    | 30/50 [00:05<00:03,  5.92it/s]

C_loss: 1.3185510635375977

 64%|██████▍   | 32/50 [00:06<00:03,  5.90it/s]

C_loss: 1.1591759920120242

 68%|██████▊   | 34/50 [00:06<00:02,  5.89it/s]

C_loss: 1.0523424148559574

 72%|███████▏  | 36/50 [00:06<00:02,  5.91it/s]

C_loss: 1.1320600509643555

 76%|███████▌  | 38/50 [00:07<00:02,  5.91it/s]

C_loss: 1.0337651968002328

 80%|████████  | 40/50 [00:07<00:01,  5.89it/s]

C_loss: 0.8706068992614746

 84%|████████▍ | 42/50 [00:07<00:01,  5.93it/s]

C_loss: 0.9334814548492432

 88%|████████▊ | 44/50 [00:08<00:01,  5.90it/s]

C_loss: 0.7810992002487183

 92%|█████████▏| 46/50 [00:08<00:00,  5.90it/s]

C_loss: 0.6858808398246765

 96%|█████████▌| 48/50 [00:08<00:00,  5.91it/s]

C_loss: 0.6641595959663391

100%|██████████| 50/50 [00:09<00:00,  5.50it/s]


C_loss: 0.7569795846939087

100%|██████████| 20/20 [00:01<00:00, 13.25it/s]


accuracy: 0.7288

  2%|▏         | 1/50 [00:00<00:08,  6.00it/s]

C_loss: 0.593142032623291

  4%|▍         | 2/50 [00:00<00:07,  6.03it/s]

C_loss: 0.5464442372322083

  6%|▌         | 3/50 [00:00<00:07,  5.99it/s]

C_loss: 0.6819764375686646

  8%|▊         | 4/50 [00:00<00:07,  6.03it/s]

C_loss: 0.9631919860839844

 10%|█         | 5/50 [00:00<00:07,  6.05it/s]

C_loss: 1.3359291553497314

 12%|█▏        | 6/50 [00:00<00:07,  6.06it/s]

C_loss: 0.79175865650177

 14%|█▍        | 7/50 [00:01<00:07,  6.05it/s]

C_loss: 1.6184756755828857

 16%|█▌        | 8/50 [00:01<00:06,  6.06it/s]

C_loss: 1.6345750093460083

 18%|█▊        | 9/50 [00:01<00:06,  6.06it/s]

C_loss: 1.2757965326309204

 20%|██        | 10/50 [00:01<00:06,  6.04it/s]

C_loss: 2.3360841274261475

 22%|██▏       | 11/50 [00:01<00:06,  6.07it/s]

C_loss: 8.740842819213867

 24%|██▍       | 12/50 [00:01<00:06,  6.06it/s]

C_loss: 2.594076633453369

 26%|██▌       | 13/50 [00:02<00:06,  6.03it/s]

C_loss: 3.02414870262146

 28%|██▊       | 14/50 [00:02<00:05,  6.04it/s]

C_loss: 5.369400978088379

 30%|███       | 15/50 [00:02<00:05,  6.00it/s]

C_loss: 3.94326114654541

 32%|███▏      | 16/50 [00:02<00:05,  6.01it/s]

C_loss: 3.041961431503296

 34%|███▍      | 17/50 [00:02<00:05,  6.03it/s]

C_loss: 2.5806162357330322

 36%|███▌      | 18/50 [00:02<00:05,  6.02it/s]

C_loss: 2.18591046333313

 38%|███▊      | 19/50 [00:03<00:05,  6.04it/s]

C_loss: 2.2908859252929688

 40%|████      | 20/50 [00:03<00:04,  6.04it/s]

C_loss: 2.292912721633911

 42%|████▏     | 21/50 [00:03<00:04,  6.03it/s]

C_loss: 2.06374454498291

 44%|████▍     | 22/50 [00:03<00:04,  6.03it/s]

C_loss: 1.7823636531829834

 46%|████▌     | 23/50 [00:03<00:04,  6.00it/s]

C_loss: 1.523952841758728

 48%|████▊     | 24/50 [00:03<00:04,  6.01it/s]

C_loss: 1.025123953819275

 50%|█████     | 25/50 [00:04<00:04,  5.99it/s]

C_loss: 0.8984978199005127

 52%|█████▏    | 26/50 [00:04<00:04,  5.98it/s]

C_loss: 0.5769519805908203

 54%|█████▍    | 27/50 [00:04<00:03,  6.01it/s]

C_loss: 0.6685912609100342

 56%|█████▌    | 28/50 [00:04<00:03,  6.00it/s]

C_loss: 0.607263445854187

 58%|█████▊    | 29/50 [00:04<00:03,  5.98it/s]

C_loss: 0.4965532720088959

 60%|██████    | 30/50 [00:04<00:03,  5.99it/s]

C_loss: 0.3800179064273834

 62%|██████▏   | 31/50 [00:05<00:03,  5.99it/s]

C_loss: 0.338340163230896

 64%|██████▍   | 32/50 [00:05<00:03,  6.00it/s]

C_loss: 0.4639054536819458

 66%|██████▌   | 33/50 [00:05<00:02,  5.99it/s]

C_loss: 0.3923502564430237

 68%|██████▊   | 34/50 [00:05<00:02,  6.01it/s]

C_loss: 0.3525589108467102

 70%|███████   | 35/50 [00:05<00:02,  6.03it/s]

C_loss: 0.36504918336868286

 72%|███████▏  | 36/50 [00:05<00:02,  6.04it/s]

C_loss: 0.374184787273407

 74%|███████▍  | 37/50 [00:06<00:02,  6.03it/s]

C_loss: 0.3686545193195343

 76%|███████▌  | 38/50 [00:06<00:01,  6.02it/s]

C_loss: 0.42123132944107056

 78%|███████▊  | 39/50 [00:06<00:01,  6.01it/s]

C_loss: 0.5040431022644043

 80%|████████  | 40/50 [00:06<00:01,  6.00it/s]

C_loss: 0.45877599716186523

 82%|████████▏ | 41/50 [00:06<00:01,  6.00it/s]

C_loss: 0.34630537033081055

 84%|████████▍ | 42/50 [00:06<00:01,  5.95it/s]

C_loss: 0.42193323373794556

 86%|████████▌ | 43/50 [00:07<00:01,  5.97it/s]

C_loss: 0.34999510645866394

 88%|████████▊ | 44/50 [00:07<00:01,  5.96it/s]

C_loss: 0.3041464388370514

 90%|█████████ | 45/50 [00:07<00:00,  5.98it/s]

C_loss: 0.36604711413383484

 92%|█████████▏| 46/50 [00:07<00:00,  6.01it/s]

C_loss: 0.39486873149871826

 94%|█████████▍| 47/50 [00:07<00:00,  6.01it/s]

C_loss: 0.3804394602775574

 96%|█████████▌| 48/50 [00:07<00:00,  6.00it/s]

C_loss: 0.3432959020137787

 98%|█████████▊| 49/50 [00:08<00:00,  6.02it/s]

C_loss: 0.4025803208351135

100%|██████████| 50/50 [00:08<00:00,  6.01it/s]


C_loss: 0.3285348117351532

100%|██████████| 20/20 [00:01<00:00, 13.89it/s]


accuracy: 0.8048

  2%|▏         | 1/50 [00:00<00:08,  6.08it/s]

C_loss: 0.32386136054992676

  4%|▍         | 2/50 [00:00<00:08,  5.99it/s]

C_loss: 0.34174299240112305

  6%|▌         | 3/50 [00:00<00:07,  6.01it/s]

C_loss: 0.3791801333427429

  8%|▊         | 4/50 [00:00<00:07,  6.04it/s]

C_loss: 0.38038355112075806

 10%|█         | 5/50 [00:00<00:07,  6.00it/s]

C_loss: 0.32637083530426025

 12%|█▏        | 6/50 [00:00<00:07,  6.02it/s]

C_loss: 0.3995647728443146

 14%|█▍        | 7/50 [00:01<00:07,  5.99it/s]

C_loss: 0.27795374393463135

 16%|█▌        | 8/50 [00:01<00:07,  6.00it/s]

C_loss: 0.34068572521209717

 18%|█▊        | 9/50 [00:01<00:06,  6.01it/s]

C_loss: 0.3237266540527344

 20%|██        | 10/50 [00:01<00:06,  5.95it/s]

C_loss: 0.32265526056289673

 22%|██▏       | 11/50 [00:01<00:06,  5.95it/s]

C_loss: 0.3164564073085785

 24%|██▍       | 12/50 [00:02<00:06,  5.95it/s]

C_loss: 0.30114078521728516

 26%|██▌       | 13/50 [00:02<00:06,  5.96it/s]

C_loss: 0.299046128988266

 28%|██▊       | 14/50 [00:02<00:06,  5.94it/s]

C_loss: 0.30820104479789734

 30%|███       | 15/50 [00:02<00:05,  5.96it/s]

C_loss: 0.24114099144935608

 32%|███▏      | 16/50 [00:02<00:05,  5.96it/s]

C_loss: 0.30015453696250916

 34%|███▍      | 17/50 [00:02<00:05,  5.97it/s]

C_loss: 0.27583014965057373

 36%|███▌      | 18/50 [00:03<00:05,  5.95it/s]

C_loss: 0.31939607858657837

 38%|███▊      | 19/50 [00:03<00:05,  5.98it/s]

C_loss: 0.24837622046470642

 40%|████      | 20/50 [00:03<00:05,  5.98it/s]

C_loss: 0.30484652519226074

 42%|████▏     | 21/50 [00:03<00:04,  6.00it/s]

C_loss: 0.2883237898349762

 44%|████▍     | 22/50 [00:03<00:04,  5.99it/s]

C_loss: 0.3090200424194336

 46%|████▌     | 23/50 [00:03<00:04,  5.98it/s]

C_loss: 0.38078704476356506

 48%|████▊     | 24/50 [00:04<00:04,  5.99it/s]

C_loss: 0.3517204225063324

 50%|█████     | 25/50 [00:04<00:04,  6.00it/s]

C_loss: 0.3429718315601349

 52%|█████▏    | 26/50 [00:04<00:04,  6.00it/s]

C_loss: 0.2940753698348999

 54%|█████▍    | 27/50 [00:04<00:03,  6.01it/s]

C_loss: 0.25442904233932495

 56%|█████▌    | 28/50 [00:04<00:03,  5.96it/s]

C_loss: 0.3110915422439575

 58%|█████▊    | 29/50 [00:04<00:03,  5.96it/s]

C_loss: 0.3491346836090088

 60%|██████    | 30/50 [00:05<00:03,  5.98it/s]

C_loss: 0.2603093087673187

 62%|██████▏   | 31/50 [00:05<00:03,  5.99it/s]

C_loss: 0.27604615688323975

 64%|██████▍   | 32/50 [00:05<00:03,  6.00it/s]

C_loss: 0.3051874041557312

 66%|██████▌   | 33/50 [00:05<00:02,  6.01it/s]

C_loss: 0.3087770342826843

 68%|██████▊   | 34/50 [00:05<00:02,  6.01it/s]

C_loss: 0.2642587423324585

 70%|███████   | 35/50 [00:05<00:02,  6.01it/s]

C_loss: 0.29190096259117126

 72%|███████▏  | 36/50 [00:06<00:02,  6.02it/s]

C_loss: 0.2713605761528015

 74%|███████▍  | 37/50 [00:06<00:02,  6.01it/s]

C_loss: 0.2955860197544098

 76%|███████▌  | 38/50 [00:06<00:01,  6.02it/s]

C_loss: 0.3388064503669739

 78%|███████▊  | 39/50 [00:06<00:01,  6.03it/s]

C_loss: 0.27516621351242065

 80%|████████  | 40/50 [00:06<00:01,  6.03it/s]

C_loss: 0.25528547167778015

 82%|████████▏ | 41/50 [00:06<00:01,  6.01it/s]

C_loss: 0.27207353711128235

 84%|████████▍ | 42/50 [00:07<00:01,  6.01it/s]

C_loss: 0.21454226970672607

 86%|████████▌ | 43/50 [00:07<00:01,  6.01it/s]

C_loss: 0.2411506175994873

 88%|████████▊ | 44/50 [00:07<00:00,  6.00it/s]

C_loss: 0.26071515679359436

 90%|█████████ | 45/50 [00:07<00:00,  6.00it/s]

C_loss: 0.2665778696537018

 92%|█████████▏| 46/50 [00:07<00:00,  6.00it/s]

C_loss: 0.2667895257472992

 94%|█████████▍| 47/50 [00:07<00:00,  6.00it/s]

C_loss: 0.2969384789466858

 96%|█████████▌| 48/50 [00:08<00:00,  6.00it/s]

C_loss: 0.2551073133945465

 98%|█████████▊| 49/50 [00:08<00:00,  5.99it/s]

C_loss: 0.3259851038455963

100%|██████████| 50/50 [00:08<00:00,  5.99it/s]


C_loss: 0.19646066427230835

100%|██████████| 20/20 [00:01<00:00, 13.52it/s]


accuracy: 0.8282

  2%|▏         | 1/50 [00:00<00:08,  6.02it/s]

C_loss: 0.2545713782310486

  4%|▍         | 2/50 [00:00<00:07,  6.02it/s]

C_loss: 0.22566978633403778

  6%|▌         | 3/50 [00:00<00:07,  6.02it/s]

C_loss: 0.29826533794403076

  8%|▊         | 4/50 [00:00<00:07,  5.98it/s]

C_loss: 0.40193235874176025

 10%|█         | 5/50 [00:00<00:07,  6.00it/s]

C_loss: 0.23851370811462402

 12%|█▏        | 6/50 [00:00<00:07,  6.00it/s]

C_loss: 0.2749395966529846

 14%|█▍        | 7/50 [00:01<00:07,  5.99it/s]

C_loss: 0.3116953372955322

 16%|█▌        | 8/50 [00:01<00:06,  6.01it/s]

C_loss: 0.26386600732803345

 18%|█▊        | 9/50 [00:01<00:06,  6.00it/s]

C_loss: 0.2972228527069092

 20%|██        | 10/50 [00:01<00:06,  5.99it/s]

C_loss: 0.27044785022735596

 22%|██▏       | 11/50 [00:01<00:06,  5.99it/s]

C_loss: 0.37967079877853394

 24%|██▍       | 12/50 [00:02<00:06,  5.99it/s]

C_loss: 0.32338300347328186

 26%|██▌       | 13/50 [00:02<00:06,  5.96it/s]

C_loss: 0.2986753284931183

 28%|██▊       | 14/50 [00:02<00:06,  5.94it/s]

C_loss: 0.2958791255950928

 30%|███       | 15/50 [00:02<00:05,  5.94it/s]

C_loss: 0.24468974769115448

 32%|███▏      | 16/50 [00:02<00:05,  5.93it/s]

C_loss: 0.2803521752357483

 34%|███▍      | 17/50 [00:02<00:05,  5.90it/s]

C_loss: 0.2323765754699707

 36%|███▌      | 18/50 [00:03<00:05,  5.93it/s]

C_loss: 0.28387609124183655

 38%|███▊      | 19/50 [00:03<00:05,  5.95it/s]

C_loss: 0.26450878381729126

 40%|████      | 20/50 [00:03<00:05,  5.96it/s]

C_loss: 0.27545684576034546

 42%|████▏     | 21/50 [00:03<00:04,  5.96it/s]

C_loss: 0.3386673629283905

 44%|████▍     | 22/50 [00:03<00:04,  5.96it/s]

C_loss: 0.27145934104919434

 46%|████▌     | 23/50 [00:03<00:04,  5.95it/s]

C_loss: 0.3786420226097107

 48%|████▊     | 24/50 [00:04<00:04,  5.93it/s]

C_loss: 0.30762767791748047

 50%|█████     | 25/50 [00:04<00:04,  5.92it/s]

C_loss: 0.4007576107978821

 52%|█████▏    | 26/50 [00:04<00:04,  5.90it/s]

C_loss: 0.41013434529304504

 54%|█████▍    | 27/50 [00:04<00:03,  5.94it/s]

C_loss: 0.5572273135185242

 56%|█████▌    | 28/50 [00:04<00:03,  5.94it/s]

C_loss: 0.5924159288406372

 58%|█████▊    | 29/50 [00:04<00:03,  5.93it/s]

C_loss: 0.7725938558578491

 60%|██████    | 30/50 [00:05<00:03,  5.89it/s]

C_loss: 0.572585940361023

 62%|██████▏   | 31/50 [00:05<00:03,  5.91it/s]

C_loss: 0.6216835379600525

 64%|██████▍   | 32/50 [00:05<00:03,  5.92it/s]

C_loss: 0.9021425247192383

 66%|██████▌   | 33/50 [00:05<00:02,  5.91it/s]

C_loss: 0.4080304503440857

 68%|██████▊   | 34/50 [00:05<00:02,  5.92it/s]

C_loss: 0.5585464239120483

 70%|███████   | 35/50 [00:05<00:02,  5.94it/s]

C_loss: 0.4590781331062317

 72%|███████▏  | 36/50 [00:06<00:02,  5.91it/s]

C_loss: 0.4240848422050476

 74%|███████▍  | 37/50 [00:06<00:02,  5.94it/s]

C_loss: 0.38738855719566345

 76%|███████▌  | 38/50 [00:06<00:02,  5.96it/s]

C_loss: 0.40592139959335327

 78%|███████▊  | 39/50 [00:06<00:01,  5.97it/s]

C_loss: 0.3475346267223358

 80%|████████  | 40/50 [00:06<00:01,  5.94it/s]

C_loss: 0.355556845664978

 82%|████████▏ | 41/50 [00:06<00:01,  5.93it/s]

C_loss: 0.4124458134174347

 84%|████████▍ | 42/50 [00:07<00:01,  5.92it/s]

C_loss: 0.43349647521972656

 86%|████████▌ | 43/50 [00:07<00:01,  5.94it/s]

C_loss: 0.31865039467811584

 88%|████████▊ | 44/50 [00:07<00:01,  5.96it/s]

C_loss: 0.29174596071243286

 90%|█████████ | 45/50 [00:07<00:00,  5.94it/s]

C_loss: 0.2193869799375534

 92%|█████████▏| 46/50 [00:07<00:00,  5.96it/s]

C_loss: 0.2762129306793213

 94%|█████████▍| 47/50 [00:07<00:00,  5.96it/s]

C_loss: 0.22233210504055023

 96%|█████████▌| 48/50 [00:08<00:00,  5.97it/s]

C_loss: 0.29208239912986755

 98%|█████████▊| 49/50 [00:08<00:00,  5.94it/s]

C_loss: 0.23802515864372253

100%|██████████| 50/50 [00:08<00:00,  5.95it/s]


C_loss: 0.26181477308273315

100%|██████████| 20/20 [00:01<00:00, 13.64it/s]


accuracy: 0.8282

  2%|▏         | 1/50 [00:00<00:08,  6.06it/s]

C_loss: 0.2977971136569977

  4%|▍         | 2/50 [00:00<00:08,  5.95it/s]

C_loss: 0.24201488494873047

  6%|▌         | 3/50 [00:00<00:07,  5.95it/s]

C_loss: 0.20945052802562714

  8%|▊         | 4/50 [00:00<00:07,  5.95it/s]

C_loss: 0.26615384221076965

 10%|█         | 5/50 [00:00<00:07,  5.94it/s]

C_loss: 0.24913290143013

 12%|█▏        | 6/50 [00:01<00:07,  5.95it/s]

C_loss: 0.2581791281700134

 14%|█▍        | 7/50 [00:01<00:07,  5.97it/s]

C_loss: 0.2554109990596771

 16%|█▌        | 8/50 [00:01<00:07,  5.97it/s]

C_loss: 0.2463403344154358

 18%|█▊        | 9/50 [00:01<00:06,  5.97it/s]

C_loss: 0.24755924940109253

 20%|██        | 10/50 [00:01<00:06,  5.98it/s]

C_loss: 0.21609926223754883

 22%|██▏       | 11/50 [00:01<00:06,  5.96it/s]

C_loss: 0.2203262895345688

 24%|██▍       | 12/50 [00:02<00:06,  5.94it/s]

C_loss: 0.22961965203285217

 26%|██▌       | 13/50 [00:02<00:06,  5.94it/s]

C_loss: 0.1743423342704773

 28%|██▊       | 14/50 [00:02<00:06,  5.95it/s]

C_loss: 0.1813933551311493

 30%|███       | 15/50 [00:02<00:05,  5.92it/s]

C_loss: 0.22066240012645721

 32%|███▏      | 16/50 [00:02<00:05,  5.91it/s]

C_loss: 0.24072331190109253

 34%|███▍      | 17/50 [00:02<00:05,  5.93it/s]

C_loss: 0.1732710301876068

 36%|███▌      | 18/50 [00:03<00:05,  5.94it/s]

C_loss: 0.16565300524234772

 38%|███▊      | 19/50 [00:03<00:05,  5.95it/s]

C_loss: 0.2184397578239441

 40%|████      | 20/50 [00:03<00:05,  5.92it/s]

C_loss: 0.2217469960451126

 42%|████▏     | 21/50 [00:03<00:04,  5.91it/s]

C_loss: 0.19798801839351654

 44%|████▍     | 22/50 [00:03<00:04,  5.92it/s]

C_loss: 0.1538369208574295

 46%|████▌     | 23/50 [00:03<00:04,  5.94it/s]

C_loss: 0.1930154412984848

 48%|████▊     | 24/50 [00:04<00:04,  5.93it/s]

C_loss: 0.18076354265213013

 50%|█████     | 25/50 [00:04<00:04,  5.93it/s]

C_loss: 0.2068423330783844

 52%|█████▏    | 26/50 [00:04<00:04,  5.94it/s]

C_loss: 0.17818781733512878

 54%|█████▍    | 27/50 [00:04<00:03,  5.95it/s]

C_loss: 0.23549017310142517

 56%|█████▌    | 28/50 [00:04<00:03,  5.95it/s]

C_loss: 0.1673617959022522

 58%|█████▊    | 29/50 [00:04<00:03,  5.94it/s]

C_loss: 0.19402623176574707

 60%|██████    | 30/50 [00:05<00:03,  5.95it/s]

C_loss: 0.19938379526138306

 62%|██████▏   | 31/50 [00:05<00:03,  5.96it/s]

C_loss: 0.21298742294311523

 64%|██████▍   | 32/50 [00:05<00:03,  5.95it/s]

C_loss: 0.22372028231620789

 66%|██████▌   | 33/50 [00:05<00:02,  5.96it/s]

C_loss: 0.17723479866981506

 68%|██████▊   | 34/50 [00:05<00:02,  5.95it/s]

C_loss: 0.19402523338794708

 70%|███████   | 35/50 [00:05<00:02,  5.94it/s]

C_loss: 0.21914833784103394

 72%|███████▏  | 36/50 [00:06<00:02,  5.96it/s]

C_loss: 0.23971036076545715

 74%|███████▍  | 37/50 [00:06<00:02,  5.93it/s]

C_loss: 0.2481136918067932

 76%|███████▌  | 38/50 [00:06<00:02,  5.95it/s]

C_loss: 0.2591332793235779

 78%|███████▊  | 39/50 [00:06<00:01,  5.96it/s]

C_loss: 0.25165092945098877

 80%|████████  | 40/50 [00:06<00:01,  5.97it/s]

C_loss: 0.2338385134935379

 82%|████████▏ | 41/50 [00:06<00:01,  5.94it/s]

C_loss: 0.29889023303985596

 84%|████████▍ | 42/50 [00:07<00:01,  5.94it/s]

C_loss: 0.19307471811771393

 86%|████████▌ | 43/50 [00:07<00:01,  5.92it/s]

C_loss: 0.2096220999956131

 88%|████████▊ | 44/50 [00:07<00:01,  5.92it/s]

C_loss: 0.19338594377040863

 90%|█████████ | 45/50 [00:07<00:00,  5.94it/s]

C_loss: 0.22730287909507751

 92%|█████████▏| 46/50 [00:07<00:00,  5.95it/s]

C_loss: 0.18228140473365784

 94%|█████████▍| 47/50 [00:07<00:00,  5.96it/s]

C_loss: 0.20743992924690247

 96%|█████████▌| 48/50 [00:08<00:00,  5.97it/s]

C_loss: 0.21178072690963745

 98%|█████████▊| 49/50 [00:08<00:00,  5.93it/s]

C_loss: 0.22627021372318268

100%|██████████| 50/50 [00:08<00:00,  5.95it/s]


C_loss: 0.22548463940620422

100%|██████████| 20/20 [00:01<00:00, 13.28it/s]


accuracy: 0.8282

  2%|▏         | 1/50 [00:00<00:08,  5.93it/s]

C_loss: 0.22173357009887695

  4%|▍         | 2/50 [00:00<00:08,  5.96it/s]

C_loss: 0.22434410452842712

  6%|▌         | 3/50 [00:00<00:07,  5.98it/s]

C_loss: 0.1918865144252777

  8%|▊         | 4/50 [00:00<00:07,  5.99it/s]

C_loss: 0.17582149803638458

 10%|█         | 5/50 [00:00<00:07,  5.97it/s]

C_loss: 0.1854715198278427

 12%|█▏        | 6/50 [00:01<00:07,  5.95it/s]

C_loss: 0.18243390321731567

 14%|█▍        | 7/50 [00:01<00:07,  5.88it/s]

C_loss: 0.2138960212469101

 16%|█▌        | 8/50 [00:01<00:07,  5.85it/s]

C_loss: 0.2193809151649475

 18%|█▊        | 9/50 [00:01<00:06,  5.89it/s]

C_loss: 0.20245619118213654

 20%|██        | 10/50 [00:01<00:06,  5.91it/s]

C_loss: 0.2245023399591446

 22%|██▏       | 11/50 [00:01<00:06,  5.91it/s]

C_loss: 0.18710777163505554

 24%|██▍       | 12/50 [00:02<00:06,  5.90it/s]

C_loss: 0.24334213137626648

 26%|██▌       | 13/50 [00:02<00:06,  5.89it/s]

C_loss: 0.1904798150062561

 28%|██▊       | 14/50 [00:02<00:06,  5.89it/s]

C_loss: 0.18392761051654816

 30%|███       | 15/50 [00:02<00:05,  5.89it/s]

C_loss: 0.2706625461578369

 32%|███▏      | 16/50 [00:02<00:05,  5.91it/s]

C_loss: 0.1889614462852478

 34%|███▍      | 17/50 [00:02<00:05,  5.89it/s]

C_loss: 0.2369501292705536

 36%|███▌      | 18/50 [00:03<00:05,  5.88it/s]

C_loss: 0.21261121332645416

 38%|███▊      | 19/50 [00:03<00:05,  5.91it/s]

C_loss: 0.20081551373004913

 40%|████      | 20/50 [00:03<00:05,  5.92it/s]

C_loss: 0.19535213708877563

 42%|████▏     | 21/50 [00:03<00:04,  5.95it/s]

C_loss: 0.20574435591697693

 44%|████▍     | 22/50 [00:03<00:04,  5.95it/s]

C_loss: 0.19513046741485596

 46%|████▌     | 23/50 [00:03<00:04,  5.96it/s]

C_loss: 0.1868058443069458

 48%|████▊     | 24/50 [00:04<00:04,  5.93it/s]

C_loss: 0.20985358953475952

 50%|█████     | 25/50 [00:04<00:04,  5.95it/s]

C_loss: 0.21000540256500244

 52%|█████▏    | 26/50 [00:04<00:04,  5.97it/s]

C_loss: 0.23799748718738556

 54%|█████▍    | 27/50 [00:04<00:03,  5.99it/s]

C_loss: 0.21119405329227448

 56%|█████▌    | 28/50 [00:04<00:03,  6.00it/s]

C_loss: 0.20730352401733398

 58%|█████▊    | 29/50 [00:04<00:03,  6.01it/s]

C_loss: 0.21109062433242798

 60%|██████    | 30/50 [00:05<00:03,  6.00it/s]

C_loss: 0.2025226503610611

 62%|██████▏   | 31/50 [00:05<00:03,  5.99it/s]

C_loss: 0.27507904171943665

 64%|██████▍   | 32/50 [00:05<00:03,  5.98it/s]

C_loss: 0.2349417507648468

 66%|██████▌   | 33/50 [00:05<00:02,  5.97it/s]

C_loss: 0.22635138034820557

 68%|██████▊   | 34/50 [00:05<00:02,  5.96it/s]

C_loss: 0.24856078624725342

 70%|███████   | 35/50 [00:05<00:02,  5.95it/s]

C_loss: 0.24103742837905884

 72%|███████▏  | 36/50 [00:06<00:02,  5.92it/s]

C_loss: 0.2395288050174713

 74%|███████▍  | 37/50 [00:06<00:02,  5.93it/s]

C_loss: 0.2737348675727844

 76%|███████▌  | 38/50 [00:06<00:02,  5.94it/s]

C_loss: 0.2014612853527069

 78%|███████▊  | 39/50 [00:06<00:01,  5.94it/s]

C_loss: 0.23052899539470673

 80%|████████  | 40/50 [00:06<00:01,  5.92it/s]

C_loss: 0.19742608070373535

 82%|████████▏ | 41/50 [00:06<00:01,  5.94it/s]

C_loss: 0.2962310314178467

 84%|████████▍ | 42/50 [00:07<00:01,  5.94it/s]

C_loss: 0.22279508411884308

 86%|████████▌ | 43/50 [00:07<00:01,  5.94it/s]

C_loss: 0.20247703790664673

 88%|████████▊ | 44/50 [00:07<00:01,  5.94it/s]

C_loss: 0.19322699308395386

 90%|█████████ | 45/50 [00:07<00:00,  5.94it/s]

C_loss: 0.2449449598789215

 92%|█████████▏| 46/50 [00:07<00:00,  5.93it/s]

C_loss: 0.20631571114063263

 94%|█████████▍| 47/50 [00:07<00:00,  5.92it/s]

C_loss: 0.2070004940032959

 96%|█████████▌| 48/50 [00:08<00:00,  5.90it/s]

C_loss: 0.24347840249538422

 98%|█████████▊| 49/50 [00:08<00:00,  5.87it/s]

C_loss: 0.21024134755134583

100%|██████████| 50/50 [00:08<00:00,  5.93it/s]


C_loss: 0.21425479650497437

100%|██████████| 20/20 [00:01<00:00, 13.72it/s]


accuracy: 0.8412

  2%|▏         | 1/50 [00:00<00:08,  6.00it/s]

C_loss: 0.22259172797203064

  4%|▍         | 2/50 [00:00<00:08,  5.97it/s]

C_loss: 0.20424538850784302

  6%|▌         | 3/50 [00:00<00:07,  5.92it/s]

C_loss: 0.19062870740890503

  8%|▊         | 4/50 [00:00<00:07,  5.96it/s]

C_loss: 0.1920461356639862

 10%|█         | 5/50 [00:00<00:07,  5.98it/s]

C_loss: 0.19626283645629883

 12%|█▏        | 6/50 [00:01<00:07,  5.99it/s]

C_loss: 0.19474004209041595

 14%|█▍        | 7/50 [00:01<00:07,  5.99it/s]

C_loss: 0.1985398828983307

 16%|█▌        | 8/50 [00:01<00:07,  6.00it/s]

C_loss: 0.21340858936309814

 18%|█▊        | 9/50 [00:01<00:06,  6.01it/s]

C_loss: 0.23450994491577148

 20%|██        | 10/50 [00:01<00:06,  6.01it/s]

C_loss: 0.1824333220720291

 22%|██▏       | 11/50 [00:01<00:06,  6.00it/s]

C_loss: 0.2049713134765625

 24%|██▍       | 12/50 [00:02<00:06,  5.95it/s]

C_loss: 0.18456920981407166

 26%|██▌       | 13/50 [00:02<00:06,  5.95it/s]

C_loss: 0.22528764605522156

 28%|██▊       | 14/50 [00:02<00:06,  5.96it/s]

C_loss: 0.22234958410263062

 30%|███       | 15/50 [00:02<00:05,  5.98it/s]

C_loss: 0.21655692160129547

 32%|███▏      | 16/50 [00:02<00:05,  5.95it/s]

C_loss: 0.19235746562480927

 34%|███▍      | 17/50 [00:02<00:05,  5.98it/s]

C_loss: 0.1791350245475769

 36%|███▌      | 18/50 [00:03<00:05,  5.99it/s]

C_loss: 0.16543447971343994

 38%|███▊      | 19/50 [00:03<00:05,  6.01it/s]

C_loss: 0.19623255729675293

 40%|████      | 20/50 [00:03<00:04,  6.00it/s]

C_loss: 0.24097073078155518

 42%|████▏     | 21/50 [00:03<00:04,  6.00it/s]

C_loss: 0.2113521695137024

 44%|████▍     | 22/50 [00:03<00:04,  5.98it/s]

C_loss: 0.21796704828739166

 46%|████▌     | 23/50 [00:03<00:04,  6.00it/s]

C_loss: 0.17998385429382324

 48%|████▊     | 24/50 [00:04<00:04,  5.95it/s]

C_loss: 0.18566063046455383

 50%|█████     | 25/50 [00:04<00:04,  5.98it/s]

C_loss: 0.25662174820899963

 52%|█████▏    | 26/50 [00:04<00:04,  5.99it/s]

C_loss: 0.2150358110666275

 54%|█████▍    | 27/50 [00:04<00:03,  6.01it/s]

C_loss: 0.20328205823898315

 56%|█████▌    | 28/50 [00:04<00:03,  6.00it/s]

C_loss: 0.1987198293209076

 58%|█████▊    | 29/50 [00:04<00:03,  6.00it/s]

C_loss: 0.18031972646713257

 60%|██████    | 30/50 [00:05<00:03,  6.03it/s]

C_loss: 0.18629109859466553

 62%|██████▏   | 31/50 [00:05<00:03,  6.02it/s]

C_loss: 0.23267273604869843

 64%|██████▍   | 32/50 [00:05<00:02,  6.01it/s]

C_loss: 0.1583355963230133

 66%|██████▌   | 33/50 [00:05<00:02,  6.03it/s]

C_loss: 0.17773953080177307

 68%|██████▊   | 34/50 [00:05<00:02,  6.01it/s]

C_loss: 0.2066144198179245

 70%|███████   | 35/50 [00:05<00:02,  6.01it/s]

C_loss: 0.17319133877754211

 72%|███████▏  | 36/50 [00:06<00:02,  6.01it/s]

C_loss: 0.19133923947811127

 74%|███████▍  | 37/50 [00:06<00:02,  6.03it/s]

C_loss: 0.17730146646499634

 76%|███████▌  | 38/50 [00:06<00:01,  6.03it/s]

C_loss: 0.17044763267040253

 78%|███████▊  | 39/50 [00:06<00:01,  6.03it/s]

C_loss: 0.178232803940773

 80%|████████  | 40/50 [00:06<00:01,  6.02it/s]

C_loss: 0.16738517582416534

 82%|████████▏ | 41/50 [00:06<00:01,  6.03it/s]

C_loss: 0.16007991135120392

 84%|████████▍ | 42/50 [00:07<00:01,  6.01it/s]

C_loss: 0.16203004121780396

 86%|████████▌ | 43/50 [00:07<00:01,  6.02it/s]

C_loss: 0.16287195682525635

 88%|████████▊ | 44/50 [00:07<00:00,  6.03it/s]

C_loss: 0.20425322651863098

 90%|█████████ | 45/50 [00:07<00:00,  5.99it/s]

C_loss: 0.14188556373119354

 92%|█████████▏| 46/50 [00:07<00:00,  5.98it/s]

C_loss: 0.2223975658416748

 94%|█████████▍| 47/50 [00:07<00:00,  6.00it/s]

C_loss: 0.1837838739156723

 96%|█████████▌| 48/50 [00:08<00:00,  6.01it/s]

C_loss: 0.20686307549476624

 98%|█████████▊| 49/50 [00:08<00:00,  6.02it/s]

C_loss: 0.23432181775569916

100%|██████████| 50/50 [00:08<00:00,  6.00it/s]


C_loss: 0.16752605140209198

100%|██████████| 20/20 [00:01<00:00, 13.70it/s]


accuracy: 0.8412

  2%|▏         | 1/50 [00:00<00:08,  5.92it/s]

C_loss: 0.18382540345191956

  4%|▍         | 2/50 [00:00<00:08,  5.98it/s]

C_loss: 0.19172491133213043

  6%|▌         | 3/50 [00:00<00:07,  6.01it/s]

C_loss: 0.19350528717041016

  8%|▊         | 4/50 [00:00<00:07,  6.03it/s]

C_loss: 0.20803891122341156

 10%|█         | 5/50 [00:00<00:07,  5.99it/s]

C_loss: 0.17654122412204742

 12%|█▏        | 6/50 [00:01<00:07,  5.99it/s]

C_loss: 0.19433194398880005

 14%|█▍        | 7/50 [00:01<00:07,  6.00it/s]

C_loss: 0.19696268439292908

 16%|█▌        | 8/50 [00:01<00:07,  5.97it/s]

C_loss: 0.2116144448518753

 18%|█▊        | 9/50 [00:01<00:06,  5.96it/s]

C_loss: 0.17914217710494995

 20%|██        | 10/50 [00:01<00:06,  5.97it/s]

C_loss: 0.17964260280132294

 22%|██▏       | 11/50 [00:01<00:06,  5.97it/s]

C_loss: 0.2610635459423065

 24%|██▍       | 12/50 [00:02<00:06,  5.98it/s]

C_loss: 0.23997578024864197

 26%|██▌       | 13/50 [00:02<00:06,  5.99it/s]

C_loss: 0.19438385963439941

 28%|██▊       | 14/50 [00:02<00:06,  5.98it/s]

C_loss: 0.17887960374355316

 30%|███       | 15/50 [00:02<00:05,  5.99it/s]

C_loss: 0.19047516584396362

 32%|███▏      | 16/50 [00:02<00:05,  5.99it/s]

C_loss: 0.19304868578910828

 34%|███▍      | 17/50 [00:02<00:05,  5.99it/s]

C_loss: 0.19638586044311523

 36%|███▌      | 18/50 [00:03<00:05,  5.99it/s]

C_loss: 0.1615372747182846

 38%|███▊      | 19/50 [00:03<00:05,  6.00it/s]

C_loss: 0.21979209780693054

 40%|████      | 20/50 [00:03<00:04,  6.01it/s]

C_loss: 0.19478052854537964

 42%|████▏     | 21/50 [00:03<00:04,  6.00it/s]

C_loss: 0.2274286299943924

 44%|████▍     | 22/50 [00:03<00:04,  6.01it/s]

C_loss: 0.18063342571258545

 46%|████▌     | 23/50 [00:03<00:04,  6.00it/s]

C_loss: 0.19626954197883606

 48%|████▊     | 24/50 [00:04<00:04,  5.96it/s]

C_loss: 0.20229601860046387

 50%|█████     | 25/50 [00:04<00:04,  5.97it/s]

C_loss: 0.19839538633823395

 52%|█████▏    | 26/50 [00:04<00:04,  5.96it/s]

C_loss: 0.16587623953819275

 54%|█████▍    | 27/50 [00:04<00:03,  5.92it/s]

C_loss: 0.18447265028953552

 56%|█████▌    | 28/50 [00:04<00:03,  5.92it/s]

C_loss: 0.18670377135276794

 58%|█████▊    | 29/50 [00:04<00:03,  5.95it/s]

C_loss: 0.19949400424957275

 60%|██████    | 30/50 [00:05<00:03,  5.92it/s]

C_loss: 0.18126696348190308

 62%|██████▏   | 31/50 [00:05<00:03,  5.95it/s]

C_loss: 0.17605918645858765

 64%|██████▍   | 32/50 [00:05<00:03,  5.97it/s]

C_loss: 0.16779913008213043

 66%|██████▌   | 33/50 [00:05<00:02,  5.98it/s]

C_loss: 0.19477270543575287

 68%|██████▊   | 34/50 [00:05<00:02,  5.99it/s]

C_loss: 0.20450708270072937

 70%|███████   | 35/50 [00:05<00:02,  5.92it/s]

C_loss: 0.1891305148601532

 72%|███████▏  | 36/50 [00:06<00:02,  5.95it/s]

C_loss: 0.16529789566993713

 74%|███████▍  | 37/50 [00:06<00:02,  5.96it/s]

C_loss: 0.20636782050132751

 76%|███████▌  | 38/50 [00:06<00:02,  5.97it/s]

C_loss: 0.17712147533893585

 78%|███████▊  | 39/50 [00:06<00:01,  5.96it/s]

C_loss: 0.17768818140029907

 80%|████████  | 40/50 [00:06<00:01,  5.98it/s]

C_loss: 0.19540467858314514

 82%|████████▏ | 41/50 [00:06<00:01,  6.00it/s]

C_loss: 0.188749298453331

 84%|████████▍ | 42/50 [00:07<00:01,  6.00it/s]

C_loss: 0.22305703163146973

 86%|████████▌ | 43/50 [00:07<00:01,  6.01it/s]

C_loss: 0.18121568858623505

 88%|████████▊ | 44/50 [00:07<00:00,  6.01it/s]

C_loss: 0.20386123657226562

 90%|█████████ | 45/50 [00:07<00:00,  6.01it/s]

C_loss: 0.24364136159420013

 92%|█████████▏| 46/50 [00:07<00:00,  6.00it/s]

C_loss: 0.23804634809494019

 94%|█████████▍| 47/50 [00:07<00:00,  6.01it/s]

C_loss: 0.1731848120689392

 96%|█████████▌| 48/50 [00:08<00:00,  6.04it/s]

C_loss: 0.16726675629615784

 98%|█████████▊| 49/50 [00:08<00:00,  6.00it/s]

C_loss: 0.18544310331344604

100%|██████████| 50/50 [00:08<00:00,  5.98it/s]


C_loss: 0.1838194727897644

100%|██████████| 20/20 [00:01<00:00, 13.72it/s]


accuracy: 0.8591

  2%|▏         | 1/50 [00:00<00:08,  6.02it/s]

C_loss: 0.19190233945846558

  4%|▍         | 2/50 [00:00<00:08,  5.93it/s]

C_loss: 0.17142775654792786

  6%|▌         | 3/50 [00:00<00:07,  5.98it/s]

C_loss: 0.19477435946464539

  8%|▊         | 4/50 [00:00<00:07,  5.99it/s]

C_loss: 0.20312632620334625

 10%|█         | 5/50 [00:00<00:07,  6.00it/s]

C_loss: 0.22211116552352905

 12%|█▏        | 6/50 [00:01<00:07,  6.01it/s]

C_loss: 0.21113571524620056

 14%|█▍        | 7/50 [00:01<00:07,  6.00it/s]

C_loss: 0.21977485716342926

 16%|█▌        | 8/50 [00:01<00:06,  6.00it/s]

C_loss: 0.1900046318769455

 18%|█▊        | 9/50 [00:01<00:06,  6.00it/s]

C_loss: 0.1889057457447052

 20%|██        | 10/50 [00:01<00:06,  5.95it/s]

C_loss: 0.2226237803697586

 22%|██▏       | 11/50 [00:01<00:06,  5.96it/s]

C_loss: 0.1666371077299118

 24%|██▍       | 12/50 [00:02<00:06,  5.97it/s]

C_loss: 0.17084799706935883

 26%|██▌       | 13/50 [00:02<00:06,  5.97it/s]

C_loss: 0.1999550759792328

 28%|██▊       | 14/50 [00:02<00:06,  5.96it/s]

C_loss: 0.2002057433128357

 30%|███       | 15/50 [00:02<00:05,  5.93it/s]

C_loss: 0.1628427654504776

 32%|███▏      | 16/50 [00:02<00:05,  5.91it/s]

C_loss: 0.18298886716365814

 34%|███▍      | 17/50 [00:02<00:05,  5.93it/s]

C_loss: 0.1754070520401001

 36%|███▌      | 18/50 [00:03<00:05,  5.94it/s]

C_loss: 0.1752755492925644

 38%|███▊      | 19/50 [00:03<00:05,  5.96it/s]

C_loss: 0.18997260928153992

 40%|████      | 20/50 [00:03<00:05,  5.96it/s]

C_loss: 0.1940949708223343

 42%|████▏     | 21/50 [00:03<00:04,  5.97it/s]

C_loss: 0.15482158958911896

 44%|████▍     | 22/50 [00:03<00:04,  5.97it/s]

C_loss: 0.18254566192626953

 46%|████▌     | 23/50 [00:03<00:04,  5.97it/s]

C_loss: 0.19461765885353088

 48%|████▊     | 24/50 [00:04<00:04,  5.99it/s]

C_loss: 0.16697844862937927

 50%|█████     | 25/50 [00:04<00:04,  5.99it/s]

C_loss: 0.18286176025867462

 52%|█████▏    | 26/50 [00:04<00:04,  5.99it/s]

C_loss: 0.1642138510942459

 54%|█████▍    | 27/50 [00:04<00:03,  6.01it/s]

C_loss: 0.16763800382614136

 56%|█████▌    | 28/50 [00:04<00:03,  6.01it/s]

C_loss: 0.21771571040153503

 58%|█████▊    | 29/50 [00:04<00:03,  6.02it/s]

C_loss: 0.1613132655620575

 60%|██████    | 30/50 [00:05<00:03,  6.00it/s]

C_loss: 0.16814759373664856

 62%|██████▏   | 31/50 [00:05<00:03,  5.99it/s]

C_loss: 0.18665170669555664

 64%|██████▍   | 32/50 [00:05<00:02,  6.00it/s]

C_loss: 0.16732051968574524

 66%|██████▌   | 33/50 [00:05<00:02,  6.02it/s]

C_loss: 0.17001619935035706

 68%|██████▊   | 34/50 [00:05<00:02,  6.03it/s]

C_loss: 0.16106954216957092

 70%|███████   | 35/50 [00:05<00:02,  5.99it/s]

C_loss: 0.1723548024892807

 72%|███████▏  | 36/50 [00:06<00:02,  5.99it/s]

C_loss: 0.14032621681690216

 74%|███████▍  | 37/50 [00:06<00:02,  6.00it/s]

C_loss: 0.16518282890319824

 76%|███████▌  | 38/50 [00:06<00:02,  5.99it/s]

C_loss: 0.16806761920452118

 78%|███████▊  | 39/50 [00:06<00:01,  6.00it/s]

C_loss: 0.14684726297855377

 80%|████████  | 40/50 [00:06<00:01,  6.00it/s]

C_loss: 0.15308429300785065

 82%|████████▏ | 41/50 [00:06<00:01,  6.00it/s]

C_loss: 0.15306323766708374

 84%|████████▍ | 42/50 [00:07<00:01,  5.96it/s]

C_loss: 0.1546759009361267

 86%|████████▌ | 43/50 [00:07<00:01,  5.97it/s]

C_loss: 0.1595088392496109

 88%|████████▊ | 44/50 [00:07<00:01,  5.97it/s]

C_loss: 0.16615578532218933

 90%|█████████ | 45/50 [00:07<00:00,  5.99it/s]

C_loss: 0.143270805478096

 92%|█████████▏| 46/50 [00:07<00:00,  6.01it/s]

C_loss: 0.18815304338932037

 94%|█████████▍| 47/50 [00:07<00:00,  6.02it/s]

C_loss: 0.1447940468788147

 96%|█████████▌| 48/50 [00:08<00:00,  5.99it/s]

C_loss: 0.19405192136764526

 98%|█████████▊| 49/50 [00:08<00:00,  6.01it/s]

C_loss: 0.13847863674163818

100%|██████████| 50/50 [00:08<00:00,  5.98it/s]


C_loss: 0.13885597884655

100%|██████████| 20/20 [00:01<00:00, 13.75it/s]


accuracy: 0.8619

  2%|▏         | 1/50 [00:00<00:08,  6.01it/s]

C_loss: 0.1241164430975914

  4%|▍         | 2/50 [00:00<00:08,  5.97it/s]

C_loss: 0.16183580458164215

  6%|▌         | 3/50 [00:00<00:07,  5.99it/s]

C_loss: 0.14660245180130005

  8%|▊         | 4/50 [00:00<00:07,  6.01it/s]

C_loss: 0.14126478135585785

 10%|█         | 5/50 [00:00<00:07,  6.02it/s]

C_loss: 0.1421901285648346

 12%|█▏        | 6/50 [00:00<00:07,  6.01it/s]

C_loss: 0.17257742583751678

 14%|█▍        | 7/50 [00:01<00:07,  6.02it/s]

C_loss: 0.15927964448928833

 16%|█▌        | 8/50 [00:01<00:07,  6.00it/s]

C_loss: 0.14067266881465912

 18%|█▊        | 9/50 [00:01<00:06,  6.01it/s]

C_loss: 0.1759181171655655

 20%|██        | 10/50 [00:01<00:06,  6.02it/s]

C_loss: 0.16612312197685242

 22%|██▏       | 11/50 [00:01<00:06,  6.01it/s]

C_loss: 0.16176137328147888

 24%|██▍       | 12/50 [00:01<00:06,  6.01it/s]

C_loss: 0.16479623317718506

 26%|██▌       | 13/50 [00:02<00:06,  5.96it/s]

C_loss: 0.13201601803302765

 28%|██▊       | 14/50 [00:02<00:06,  5.97it/s]

C_loss: 0.14146751165390015

 30%|███       | 15/50 [00:02<00:05,  5.99it/s]

C_loss: 0.1648387908935547

 32%|███▏      | 16/50 [00:02<00:05,  6.00it/s]

C_loss: 0.136887788772583

 34%|███▍      | 17/50 [00:02<00:05,  5.97it/s]

C_loss: 0.1490139365196228

 36%|███▌      | 18/50 [00:03<00:05,  5.97it/s]

C_loss: 0.12485827505588531

 38%|███▊      | 19/50 [00:03<00:05,  5.99it/s]

C_loss: 0.1562192589044571

 40%|████      | 20/50 [00:03<00:04,  6.01it/s]

C_loss: 0.16056853532791138

 42%|████▏     | 21/50 [00:03<00:04,  6.01it/s]

C_loss: 0.13971608877182007

 44%|████▍     | 22/50 [00:03<00:04,  5.98it/s]

C_loss: 0.16085657477378845

 46%|████▌     | 23/50 [00:03<00:04,  5.98it/s]

C_loss: 0.1650434285402298

 48%|████▊     | 24/50 [00:04<00:04,  5.96it/s]

C_loss: 0.11933490633964539

 50%|█████     | 25/50 [00:04<00:04,  5.96it/s]

C_loss: 0.14103728532791138

 52%|█████▏    | 26/50 [00:04<00:04,  5.93it/s]

C_loss: 0.1608002483844757

 54%|█████▍    | 27/50 [00:04<00:03,  5.95it/s]

C_loss: 0.1683078110218048

 56%|█████▌    | 28/50 [00:04<00:03,  5.98it/s]

C_loss: 0.1869133710861206

 58%|█████▊    | 29/50 [00:04<00:03,  6.00it/s]

C_loss: 0.14422467350959778

 60%|██████    | 30/50 [00:05<00:03,  6.01it/s]

C_loss: 0.12394940853118896

 62%|██████▏   | 31/50 [00:05<00:03,  6.01it/s]

C_loss: 0.15717695653438568

 64%|██████▍   | 32/50 [00:05<00:03,  5.95it/s]

C_loss: 0.17042571306228638

 66%|██████▌   | 33/50 [00:05<00:02,  5.98it/s]

C_loss: 0.14235588908195496

 68%|██████▊   | 34/50 [00:05<00:02,  5.98it/s]

C_loss: 0.13101354241371155

 70%|███████   | 35/50 [00:05<00:02,  5.99it/s]

C_loss: 0.15705178678035736

 72%|███████▏  | 36/50 [00:06<00:02,  5.94it/s]

C_loss: 0.12926559150218964

 74%|███████▍  | 37/50 [00:06<00:02,  5.95it/s]

C_loss: 0.1560525894165039

 76%|███████▌  | 38/50 [00:06<00:02,  5.95it/s]

C_loss: 0.13992808759212494

 78%|███████▊  | 39/50 [00:06<00:01,  5.97it/s]

C_loss: 0.1576540321111679

 80%|████████  | 40/50 [00:06<00:01,  5.99it/s]

C_loss: 0.14355173707008362

 82%|████████▏ | 41/50 [00:06<00:01,  5.99it/s]

C_loss: 0.13541653752326965

 84%|████████▍ | 42/50 [00:07<00:01,  5.95it/s]

C_loss: 0.15347623825073242

 86%|████████▌ | 43/50 [00:07<00:01,  5.95it/s]

C_loss: 0.16008073091506958

 88%|████████▊ | 44/50 [00:07<00:01,  5.95it/s]

C_loss: 0.1247185617685318

 90%|█████████ | 45/50 [00:07<00:00,  5.98it/s]

C_loss: 0.13032187521457672

 92%|█████████▏| 46/50 [00:07<00:00,  5.95it/s]

C_loss: 0.14559820294380188

 94%|█████████▍| 47/50 [00:07<00:00,  5.98it/s]

C_loss: 0.14804625511169434

 96%|█████████▌| 48/50 [00:08<00:00,  5.98it/s]

C_loss: 0.17435961961746216

 98%|█████████▊| 49/50 [00:08<00:00,  6.00it/s]

C_loss: 0.13765373826026917

100%|██████████| 50/50 [00:08<00:00,  5.98it/s]


C_loss: 0.15446464717388153

100%|██████████| 20/20 [00:01<00:00, 13.71it/s]

accuracy: 0.8655




In [23]:
corrected = 0
for b in tqdm(test_dl):
    images, y = b
    outs = mGAN.classifier.forward(images)
    outs = torch.argmax(outs, dim=1)
    corrected += (outs == y).sum().item()

100%|██████████| 20/20 [00:01<00:00, 11.63it/s]


In [24]:
print(f"{i}: {corrected / test_dl.num_data()}")

4: 0.9553
