In [1]:
import numpy as np
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from itertools import product
from sklearn.model_selection import train_test_split
from pathlib import Path
import random
from tqdm import tqdm
import xml.etree.ElementTree as ET
import cv2
import plotly.graph_objs as go
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
import json
from torch.utils.data import Dataset
import torch
import cv2
import torch.nn as nn
import torch.nn.functional as F

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device


im_size = (252, 252)


# dataset definition
class thyroidDataset(Dataset):
    def __init__(self, split):
        self.all_data = []
        self.compositions = {'Unknown':0, 'cystic':1,
                             'predominantly solid':2,
                             'solid':3, 'spongiform appareance':4}
        self.echogenicities = {'Unknown':0, 'hyperechogenecity':1,
                             'hypoechogenecity':2, 'isoechogenicity':3,
                             'marked hypoechogenecity':4}
        self.margins = {'Unknown':0, 'ill- defined':1, 'microlobulated':2,
                        'spiculated':3, 'well defined smooth':4}
        self.calcifications = {'macrocalcification':0, 'microcalcification':1, 'non':2}
        self.types ={'benign':0, 'malign':1}
        self.types_count = []
        for t_type in ['benign', 'malign']:
            root_dir=Path('../data/' + split + '/' + t_type).expanduser().resolve().absolute() 
            print(root_dir)
            files = list(root_dir.glob("*"))
            labels = [self.types[t_type]] * len(files)
            self.types_count.append(len(files))
            data_list = list(zip(files, labels))
            self.all_data.extend(data_list)
        random.shuffle(self.all_data)
        self.cases, self.types = zip(*self.all_data)
        print("number of data items:" + str(len(self.cases)))
        self.sample_weights = [1/self.types_count[label] for label in self.types]
    def __len__(self):
        return len(self.cases)
  
    def __getitem__(self, idx):
        labels = np.zeros(15, dtype = float)
        xml_data = ET.parse(list(self.cases[idx].glob('*[0-9].xml'))[0]).getroot()
        for x in xml_data:
            if x.tag=='composition' and x.text is not None:
                composition = x.text
                labels[self.compositions[composition] - 1] = 1.0
            if x.tag=='echogenicity' and x.text is not None:
                echogenicity = x.text
                labels[self.echogenicities[echogenicity] + 3] = 1.0
            if x.tag=='margins' and x.text is not None:
                margin = x.text
                labels[self.margins[margin] + 7] = 1.0
            if x.tag=='calcifications' and x.text is not None:
                calcification = x.text
                labels[self.calcifications[calcification] + 11] = 1.0
        xml_data = ET.parse(list(self.cases[idx].glob('*[0-9].xml'))[0]).find("mark")
        for x in xml_data:
            if(x.tag=='svg'):
                encoded = str(x.text)
                poly_data = json.loads(x.text)
        
        #labels[15] = list(self.types)[idx]
        im_name = list(self.cases[idx].glob('*[0-9].jpg'))[0]
        im = cv2.imread(str(im_name))[:, :, 0]
        mask = np.zeros(np.shape(im))
        im = cv2.resize(im, dsize=im_size, interpolation=cv2.INTER_CUBIC)
        
        # add mask 
        for polygon in poly_data:
            xs = []
            ys = []
            for point in polygon["points"]:
                xs.append(point["x"])
                ys.append(point["y"])
            contour = np.concatenate((np.expand_dims(xs, 1), np.expand_dims(ys, 1)), axis=1)
            cv2.fillPoly(mask, pts = [contour], color =(1, 1, 1))
        
        #mask = cv2.resize(mask, dsize=(300, 300), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, dsize=im_size, interpolation=cv2.INTER_LINEAR)
        
        
        
        # Adding data augmentation to avoid overfitting
        if random.randint(1, 10) > 5:
            im = np.flipud(im)
        if random.randint(1, 10) > 5:
            im = np.fliplr(im)
        if random.randint(1, 10) > 5:
            for i in range(random.randint(1, 4)):
                im = np.rot90(im)
        im = np.ascontiguousarray(im)

        #plt.figure()
        #plt.imshow(im)

        transforms = Compose([ToTensor()])
        mask = transforms(mask)
        im = transforms(im)
        
        im = im * mask
        
        im = im.type(torch.FloatTensor)
        
        sample = {"image": im, "labels": torch.from_numpy(labels), "types" : self.types[idx], "name": str(im_name)}
        return sample

In [3]:
# Dataset creation
training_set = thyroidDataset(split='train')
parameters_train = {
    "batch_size": 32,
    #"shuffle": True,
}

parameters_test = {
    "batch_size": 1,
    "shuffle": False,
}
training_set = thyroidDataset(split='train')
training_generator = torch.utils.data.DataLoader(training_set, **parameters_train, sampler=torch.utils.data.WeightedRandomSampler(training_set.sample_weights, len(training_set.cases), replacement=True))

training_generator1 = torch.utils.data.DataLoader(training_set, **parameters_test, sampler=torch.utils.data.WeightedRandomSampler(training_set.sample_weights, len(training_set.cases), replacement=True))


testing_set = thyroidDataset(split='test')
testing_generator = torch.utils.data.DataLoader(testing_set, **parameters_test, sampler=torch.utils.data.WeightedRandomSampler(testing_set.sample_weights, len(testing_set.cases), replacement=True))


import torch.distributions

/home/ahana/thyroid/data/train/benign
/home/ahana/thyroid/data/train/malign
number of data items:73
/home/ahana/thyroid/data/train/benign
/home/ahana/thyroid/data/train/malign
number of data items:73
/home/ahana/thyroid/data/test/benign
/home/ahana/thyroid/data/test/malign
number of data items:25


In [4]:
class Encoder(nn.Module):
    
    def __init__(self, encoded_space_dim,fc2_input_dim):
        super().__init__()
        
        self.N = torch.distributions.Normal(0, 1)

        self.N.loc = self.N.loc.cuda() # hack to get sampling on the GPU
        self.N.scale = self.N.scale.cuda()
        self.kl = 0


        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, 3, stride=2, padding=0),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 3, stride=2, padding=0),
            nn.ReLU(True),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 256, 3, stride=2, padding=0),
            nn.ReLU(True)

        )
        
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
### Linear section
        self.encoder_fc = nn.Sequential(
            nn.Linear(2304 + 15, fc2_input_dim * 4),
            nn.ReLU(True),
            nn.Linear(4* fc2_input_dim, 2 * fc2_input_dim)
        )
        self.encoder_lin = nn.Sequential(
            nn.Linear(2 * fc2_input_dim, fc2_input_dim),
            nn.ReLU(True),
            nn.Linear(fc2_input_dim, encoded_space_dim)
        )
        self.encoder_lin1 = nn.Sequential(
            nn.Linear(2 * fc2_input_dim, fc2_input_dim),
            nn.ReLU(True),
            nn.Linear(fc2_input_dim, encoded_space_dim)
        )
        
    def forward(self, x, attributes):
        #print("Encoder input: ", np.shape(x))
        x = self.encoder_cnn(x)
        
        x = self.flatten(x)
        #print("Encoder flattened output: ", np.shape(x))
        #x = self.encoder_lin(x)
        x = torch.cat((x, attributes), dim=1)
        x = self.encoder_fc(x.float())
        mu =  self.encoder_lin(x)
        sigma = torch.exp(self.encoder_lin1(x))
        z = mu + sigma*self.N.sample(mu.shape)
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()

        return z
    
    
    
class Decoder(nn.Module):
    
    def __init__(self, encoded_space_dim,fc2_input_dim):
        super().__init__()
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim + 15, fc2_input_dim),
            nn.ReLU(True),
            nn.Linear(fc2_input_dim, 3 * 3 * fc2_input_dim),
            nn.ReLU(True)
        )

        self.unflatten = nn.Unflatten(dim=1, 
        unflattened_size=(fc2_input_dim, 3, 3))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(fc2_input_dim, 128, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2,  padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)
        )
        
    def forward(self, x):
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)
        return x

In [5]:
### Define the loss function
loss_fn = torch.nn.MSELoss()
loss_latent = torch.nn.L1Loss()



### Define an optimizer (both for the encoder and the decoder!)
lr= 0.001

### Set the random seed for reproducible results
torch.manual_seed(0)

### Initialize the two networks
d = 16

#model = Autoencoder(encoded_space_dim=encoded_space_dim)
encoder = Encoder(encoded_space_dim=d,fc2_input_dim=512)
decoder = Decoder(encoded_space_dim=d,fc2_input_dim=512)
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]

optim = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-05)

# Check if the GPU is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Selected device: {device}')

# Move both the encoder and the decoder to the selected device
encoder.to(device)
decoder.to(device)

Selected device: cuda


Decoder(
  (decoder_lin): Sequential(
    (0): Linear(in_features=31, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=4608, bias=True)
    (3): ReLU(inplace=True)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(512, 3, 3))
  (decoder_conv): Sequential(
    (0): ConvTranspose2d(512, 128, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2))
    (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (10): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=T

In [6]:
### Training function
def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer):
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    train_loss = []
    # Iterate the dataloader (we do not need the label values, this is unsupervised learning)
    for data in training_generator: # with "_" we just ignore the labels (the second element of the dataloader tuple)
        # Move tensor to the proper device
        image_batch = data["image"]
        image_batch = image_batch.to(device)
        
        
        #print("Random data shape", np.shape(random_data), "; data[labels] ", np.shape(data["labels"]))
        latent = data["labels"]
        latent = latent.to(device)
        
        #print(latent)
        
        # Encode data
        encoded_data = encoder(image_batch, latent)
        
        # Decode data
        
        decoder_input = torch.cat((encoded_data, latent), dim = 1)
        
        decoded_data = decoder(decoder_input.float())
        # Evaluate loss

        #print("Encoder.kl : ", encoder.kl)

        d = ((image_batch - decoded_data)**2).sum()
        loss = d + 0.5*encoder.kl

        #loss = loss_fn(decoded_data, image_batch) + 0.2*loss_latent(encoded_data, latent.float())
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Print batch loss
        #print('\t partial train loss (single batch): %f' % (loss.data))
        #train_loss.append(loss.detach().cpu().numpy())

    return loss

In [None]:
num_epochs = 100000
diz_loss = {'train_loss':[],'val_loss':[]}
for epoch in range(num_epochs):
   train_loss = train_epoch(encoder,decoder,device,training_generator,loss_fn,optim)
   print('\n EPOCH {}/{} \t train loss {}'.format(epoch + 1, num_epochs,train_loss))
   diz_loss['train_loss'].append(train_loss)



 EPOCH 1/100000 	 train loss 305662.0

 EPOCH 2/100000 	 train loss 130573.4296875

 EPOCH 3/100000 	 train loss 126032.9921875

 EPOCH 4/100000 	 train loss 120709.3671875

 EPOCH 5/100000 	 train loss 116811.390625

 EPOCH 6/100000 	 train loss 109401.40625

 EPOCH 7/100000 	 train loss 106705.609375

 EPOCH 8/100000 	 train loss 106586.9140625

 EPOCH 9/100000 	 train loss 99921.296875

 EPOCH 10/100000 	 train loss 99998.578125

 EPOCH 11/100000 	 train loss 92383.609375

 EPOCH 12/100000 	 train loss 87242.96875

 EPOCH 13/100000 	 train loss 88669.3359375

 EPOCH 14/100000 	 train loss 86757.109375

 EPOCH 15/100000 	 train loss 81917.2421875

 EPOCH 16/100000 	 train loss 79383.828125

 EPOCH 17/100000 	 train loss 78128.9296875

 EPOCH 18/100000 	 train loss 77092.796875

 EPOCH 19/100000 	 train loss 72473.6875

 EPOCH 20/100000 	 train loss 69547.109375

 EPOCH 21/100000 	 train loss 73111.0625

 EPOCH 22/100000 	 train loss 64137.3203125

 EPOCH 23/100000 	 train loss 65556


 EPOCH 179/100000 	 train loss 3221.1494140625

 EPOCH 180/100000 	 train loss 3378.384521484375

 EPOCH 181/100000 	 train loss 3331.985107421875

 EPOCH 182/100000 	 train loss 3219.158203125

 EPOCH 183/100000 	 train loss 2980.71630859375

 EPOCH 184/100000 	 train loss 3229.03759765625

 EPOCH 185/100000 	 train loss 3113.606201171875

 EPOCH 186/100000 	 train loss 3448.608154296875

 EPOCH 187/100000 	 train loss 3157.03564453125

 EPOCH 188/100000 	 train loss 3204.33837890625

 EPOCH 189/100000 	 train loss 3071.444091796875

 EPOCH 190/100000 	 train loss 3047.21142578125

 EPOCH 191/100000 	 train loss 3289.726806640625

 EPOCH 192/100000 	 train loss 2954.08740234375

 EPOCH 193/100000 	 train loss 3259.030517578125

 EPOCH 194/100000 	 train loss 2946.657470703125

 EPOCH 195/100000 	 train loss 2840.1728515625

 EPOCH 196/100000 	 train loss 2969.073974609375

 EPOCH 197/100000 	 train loss 2996.828857421875

 EPOCH 198/100000 	 train loss 3154.97998046875

 EPOCH 199/10


 EPOCH 346/100000 	 train loss 1542.8267822265625

 EPOCH 347/100000 	 train loss 1293.681640625

 EPOCH 348/100000 	 train loss 1305.8082275390625

 EPOCH 349/100000 	 train loss 1327.851806640625

 EPOCH 350/100000 	 train loss 1516.279296875

 EPOCH 351/100000 	 train loss 1584.8648681640625

 EPOCH 352/100000 	 train loss 1401.1053466796875

 EPOCH 353/100000 	 train loss 1351.260009765625

 EPOCH 354/100000 	 train loss 1251.37646484375

 EPOCH 355/100000 	 train loss 1390.8797607421875

 EPOCH 356/100000 	 train loss 1218.57275390625

 EPOCH 357/100000 	 train loss 1595.8619384765625

 EPOCH 358/100000 	 train loss 1200.2235107421875

 EPOCH 359/100000 	 train loss 1539.3154296875

 EPOCH 360/100000 	 train loss 1403.007080078125

 EPOCH 361/100000 	 train loss 1735.9892578125

 EPOCH 362/100000 	 train loss 1281.897705078125

 EPOCH 363/100000 	 train loss 1391.8033447265625

 EPOCH 364/100000 	 train loss 1245.9107666015625

 EPOCH 365/100000 	 train loss 1135.17822265625

 EP


 EPOCH 511/100000 	 train loss 991.3472900390625

 EPOCH 512/100000 	 train loss 711.0184936523438

 EPOCH 513/100000 	 train loss 826.5079345703125

 EPOCH 514/100000 	 train loss 964.9468383789062

 EPOCH 515/100000 	 train loss 1013.9359741210938

 EPOCH 516/100000 	 train loss 864.642333984375

 EPOCH 517/100000 	 train loss 1152.2469482421875

 EPOCH 518/100000 	 train loss 981.7852783203125

 EPOCH 519/100000 	 train loss 1059.8160400390625

 EPOCH 520/100000 	 train loss 861.8546752929688

 EPOCH 521/100000 	 train loss 750.1403198242188

 EPOCH 522/100000 	 train loss 1142.5897216796875

 EPOCH 523/100000 	 train loss 964.1162109375

 EPOCH 524/100000 	 train loss 881.8485107421875

 EPOCH 525/100000 	 train loss 892.5091552734375

 EPOCH 526/100000 	 train loss 875.576416015625

 EPOCH 527/100000 	 train loss 962.9749145507812

 EPOCH 528/100000 	 train loss 753.8853149414062

 EPOCH 529/100000 	 train loss 959.09619140625

 EPOCH 530/100000 	 train loss 880.1771240234375

 E


 EPOCH 677/100000 	 train loss 622.69873046875

 EPOCH 678/100000 	 train loss 557.4697875976562

 EPOCH 679/100000 	 train loss 630.8005981445312

 EPOCH 680/100000 	 train loss 1154.7835693359375

 EPOCH 681/100000 	 train loss 887.5482788085938

 EPOCH 682/100000 	 train loss 551.1986694335938

 EPOCH 683/100000 	 train loss 627.4716186523438

 EPOCH 684/100000 	 train loss 949.8656005859375

 EPOCH 685/100000 	 train loss 704.0037841796875

 EPOCH 686/100000 	 train loss 699.9603271484375

 EPOCH 687/100000 	 train loss 627.5810546875

 EPOCH 688/100000 	 train loss 637.9986572265625

 EPOCH 689/100000 	 train loss 596.1170043945312

 EPOCH 690/100000 	 train loss 824.1650390625

 EPOCH 691/100000 	 train loss 739.658447265625

 EPOCH 692/100000 	 train loss 631.401611328125

 EPOCH 693/100000 	 train loss 722.1480712890625

 EPOCH 694/100000 	 train loss 803.3535766601562

 EPOCH 695/100000 	 train loss 889.340576171875

 EPOCH 696/100000 	 train loss 626.2576293945312

 EPOCH 69


 EPOCH 843/100000 	 train loss 725.885986328125

 EPOCH 844/100000 	 train loss 906.1978149414062

 EPOCH 845/100000 	 train loss 740.3668212890625

 EPOCH 846/100000 	 train loss 648.0362548828125

 EPOCH 847/100000 	 train loss 912.1668701171875

 EPOCH 848/100000 	 train loss 848.0552978515625

 EPOCH 849/100000 	 train loss 852.6176147460938

 EPOCH 850/100000 	 train loss 518.4833374023438

 EPOCH 851/100000 	 train loss 907.088623046875

 EPOCH 852/100000 	 train loss 651.225341796875

 EPOCH 853/100000 	 train loss 766.7077026367188

 EPOCH 854/100000 	 train loss 550.6278076171875

 EPOCH 855/100000 	 train loss 743.7596435546875

 EPOCH 856/100000 	 train loss 522.4439697265625

 EPOCH 857/100000 	 train loss 921.1791381835938

 EPOCH 858/100000 	 train loss 873.92724609375

 EPOCH 859/100000 	 train loss 696.4286499023438

 EPOCH 860/100000 	 train loss 487.76361083984375

 EPOCH 861/100000 	 train loss 815.4954833984375

 EPOCH 862/100000 	 train loss 510.3927307128906

 EP


 EPOCH 1009/100000 	 train loss 546.9989013671875

 EPOCH 1010/100000 	 train loss 683.06201171875

 EPOCH 1011/100000 	 train loss 692.752685546875

 EPOCH 1012/100000 	 train loss 697.0208740234375

 EPOCH 1013/100000 	 train loss 654.5313110351562

 EPOCH 1014/100000 	 train loss 721.559814453125

 EPOCH 1015/100000 	 train loss 662.5269775390625

 EPOCH 1016/100000 	 train loss 799.930419921875

 EPOCH 1017/100000 	 train loss 722.9017944335938

 EPOCH 1018/100000 	 train loss 480.76043701171875

 EPOCH 1019/100000 	 train loss 545.3192138671875

 EPOCH 1020/100000 	 train loss 571.0099487304688

 EPOCH 1021/100000 	 train loss 672.1036987304688

 EPOCH 1022/100000 	 train loss 578.2691650390625

 EPOCH 1023/100000 	 train loss 587.70947265625

 EPOCH 1024/100000 	 train loss 749.8400268554688

 EPOCH 1025/100000 	 train loss 619.9800415039062

 EPOCH 1026/100000 	 train loss 685.5994873046875

 EPOCH 1027/100000 	 train loss 686.833984375

 EPOCH 1028/100000 	 train loss 597.1019


 EPOCH 1172/100000 	 train loss 1053.9462890625

 EPOCH 1173/100000 	 train loss 750.0146484375

 EPOCH 1174/100000 	 train loss 761.2122802734375

 EPOCH 1175/100000 	 train loss 1035.613037109375

 EPOCH 1176/100000 	 train loss 1100.1395263671875

 EPOCH 1177/100000 	 train loss 1299.385986328125

 EPOCH 1178/100000 	 train loss 841.56982421875

 EPOCH 1179/100000 	 train loss 1194.5191650390625

 EPOCH 1180/100000 	 train loss 961.8569946289062

 EPOCH 1181/100000 	 train loss 776.805419921875

 EPOCH 1182/100000 	 train loss 1070.294677734375

 EPOCH 1183/100000 	 train loss 1319.6810302734375

 EPOCH 1184/100000 	 train loss 782.6439819335938

 EPOCH 1185/100000 	 train loss 559.6221313476562

 EPOCH 1186/100000 	 train loss 601.5953369140625

 EPOCH 1187/100000 	 train loss 696.96484375

 EPOCH 1188/100000 	 train loss 766.7312622070312

 EPOCH 1189/100000 	 train loss 657.9044189453125

 EPOCH 1190/100000 	 train loss 607.7194213867188

 EPOCH 1191/100000 	 train loss 952.0270


 EPOCH 1336/100000 	 train loss 562.6553344726562

 EPOCH 1337/100000 	 train loss 449.4044189453125

 EPOCH 1338/100000 	 train loss 460.4622802734375

 EPOCH 1339/100000 	 train loss 483.6835632324219

 EPOCH 1340/100000 	 train loss 422.87969970703125

 EPOCH 1341/100000 	 train loss 639.8779296875

 EPOCH 1342/100000 	 train loss 448.0265197753906

 EPOCH 1343/100000 	 train loss 590.2454833984375

 EPOCH 1344/100000 	 train loss 549.618896484375

 EPOCH 1345/100000 	 train loss 523.4630126953125

 EPOCH 1346/100000 	 train loss 569.2945556640625

 EPOCH 1347/100000 	 train loss 432.14813232421875

 EPOCH 1348/100000 	 train loss 494.6002502441406

 EPOCH 1349/100000 	 train loss 516.5083618164062

 EPOCH 1350/100000 	 train loss 685.0679931640625

 EPOCH 1351/100000 	 train loss 597.47900390625

 EPOCH 1352/100000 	 train loss 678.3724365234375

 EPOCH 1353/100000 	 train loss 576.8902587890625

 EPOCH 1354/100000 	 train loss 591.125732421875

 EPOCH 1355/100000 	 train loss 571


 EPOCH 1499/100000 	 train loss 811.5043334960938

 EPOCH 1500/100000 	 train loss 565.6627807617188

 EPOCH 1501/100000 	 train loss 506.90362548828125

 EPOCH 1502/100000 	 train loss 513.4891967773438

 EPOCH 1503/100000 	 train loss 475.5846252441406

 EPOCH 1504/100000 	 train loss 440.4693603515625

 EPOCH 1505/100000 	 train loss 753.9906005859375

 EPOCH 1506/100000 	 train loss 424.3263854980469

 EPOCH 1507/100000 	 train loss 545.0526123046875

 EPOCH 1508/100000 	 train loss 424.2042541503906

 EPOCH 1509/100000 	 train loss 770.2837524414062

 EPOCH 1510/100000 	 train loss 377.0894470214844

 EPOCH 1511/100000 	 train loss 573.48095703125

 EPOCH 1512/100000 	 train loss 735.9659423828125

 EPOCH 1513/100000 	 train loss 473.0986633300781

 EPOCH 1514/100000 	 train loss 521.0299682617188

 EPOCH 1515/100000 	 train loss 377.247802734375

 EPOCH 1516/100000 	 train loss 504.938720703125

 EPOCH 1517/100000 	 train loss 460.7958984375

 EPOCH 1518/100000 	 train loss 534.


 EPOCH 1661/100000 	 train loss 526.34423828125

 EPOCH 1662/100000 	 train loss 517.7504272460938

 EPOCH 1663/100000 	 train loss 573.4676513671875

 EPOCH 1664/100000 	 train loss 487.08551025390625

 EPOCH 1665/100000 	 train loss 496.968505859375

 EPOCH 1666/100000 	 train loss 680.8284912109375

 EPOCH 1667/100000 	 train loss 551.2289428710938

 EPOCH 1668/100000 	 train loss 607.9566040039062

 EPOCH 1669/100000 	 train loss 602.4812622070312

 EPOCH 1670/100000 	 train loss 421.57977294921875

 EPOCH 1671/100000 	 train loss 463.99005126953125

 EPOCH 1672/100000 	 train loss 479.7209777832031

 EPOCH 1673/100000 	 train loss 760.3226318359375

 EPOCH 1674/100000 	 train loss 368.06781005859375

 EPOCH 1675/100000 	 train loss 625.991455078125

 EPOCH 1676/100000 	 train loss 503.74609375

 EPOCH 1677/100000 	 train loss 404.2813720703125

 EPOCH 1678/100000 	 train loss 560.8905639648438

 EPOCH 1679/100000 	 train loss 430.2120361328125

 EPOCH 1680/100000 	 train loss 490


 EPOCH 1823/100000 	 train loss 559.3218994140625

 EPOCH 1824/100000 	 train loss 324.7906799316406

 EPOCH 1825/100000 	 train loss 390.08612060546875

 EPOCH 1826/100000 	 train loss 397.46392822265625

 EPOCH 1827/100000 	 train loss 446.51824951171875

 EPOCH 1828/100000 	 train loss 522.7950439453125

 EPOCH 1829/100000 	 train loss 453.6707763671875

 EPOCH 1830/100000 	 train loss 613.7706909179688

 EPOCH 1831/100000 	 train loss 552.5697631835938

 EPOCH 1832/100000 	 train loss 456.3870544433594

 EPOCH 1833/100000 	 train loss 579.1397705078125

 EPOCH 1834/100000 	 train loss 506.1280517578125

 EPOCH 1835/100000 	 train loss 359.88568115234375

 EPOCH 1836/100000 	 train loss 514.6465454101562

 EPOCH 1837/100000 	 train loss 493.853271484375

 EPOCH 1838/100000 	 train loss 513.845703125

 EPOCH 1839/100000 	 train loss 533.1217651367188

 EPOCH 1840/100000 	 train loss 364.54443359375

 EPOCH 1841/100000 	 train loss 447.47802734375

 EPOCH 1842/100000 	 train loss 433


 EPOCH 1985/100000 	 train loss 1861.8333740234375

 EPOCH 1986/100000 	 train loss 536.1526489257812

 EPOCH 1987/100000 	 train loss 1366.3472900390625

 EPOCH 1988/100000 	 train loss 862.124755859375

 EPOCH 1989/100000 	 train loss 986.4824829101562

 EPOCH 1990/100000 	 train loss 669.4041748046875

 EPOCH 1991/100000 	 train loss 1138.6236572265625

 EPOCH 1992/100000 	 train loss 703.375244140625

 EPOCH 1993/100000 	 train loss 698.324462890625

 EPOCH 1994/100000 	 train loss 722.9371948242188

 EPOCH 1995/100000 	 train loss 671.4268188476562

 EPOCH 1996/100000 	 train loss 1256.9742431640625

 EPOCH 1997/100000 	 train loss 546.7200927734375

 EPOCH 1998/100000 	 train loss 518.9673461914062

 EPOCH 1999/100000 	 train loss 445.7774353027344

 EPOCH 2000/100000 	 train loss 348.7120361328125

 EPOCH 2001/100000 	 train loss 637.1160278320312

 EPOCH 2002/100000 	 train loss 505.57684326171875

 EPOCH 2003/100000 	 train loss 740.3936157226562

 EPOCH 2004/100000 	 train l


 EPOCH 2147/100000 	 train loss 631.3569946289062

 EPOCH 2148/100000 	 train loss 489.24493408203125

 EPOCH 2149/100000 	 train loss 366.36077880859375

 EPOCH 2150/100000 	 train loss 564.0923461914062

 EPOCH 2151/100000 	 train loss 563.3156127929688

 EPOCH 2152/100000 	 train loss 481.146240234375

 EPOCH 2153/100000 	 train loss 339.5690002441406

 EPOCH 2154/100000 	 train loss 368.4131774902344

 EPOCH 2155/100000 	 train loss 495.60943603515625

 EPOCH 2156/100000 	 train loss 502.9500732421875

 EPOCH 2157/100000 	 train loss 491.0130310058594

 EPOCH 2158/100000 	 train loss 456.3201599121094

 EPOCH 2159/100000 	 train loss 457.0640563964844

 EPOCH 2160/100000 	 train loss 504.6390380859375

 EPOCH 2161/100000 	 train loss 447.190673828125

 EPOCH 2162/100000 	 train loss 587.4328002929688

 EPOCH 2163/100000 	 train loss 683.84521484375

 EPOCH 2164/100000 	 train loss 438.5501403808594

 EPOCH 2165/100000 	 train loss 404.1970520019531

 EPOCH 2166/100000 	 train loss


 EPOCH 2309/100000 	 train loss 381.1305847167969

 EPOCH 2310/100000 	 train loss 336.6932067871094

 EPOCH 2311/100000 	 train loss 550.6710815429688

 EPOCH 2312/100000 	 train loss 479.54107666015625

 EPOCH 2313/100000 	 train loss 406.3166809082031

 EPOCH 2314/100000 	 train loss 452.55767822265625

 EPOCH 2315/100000 	 train loss 410.7055969238281

 EPOCH 2316/100000 	 train loss 313.98089599609375

 EPOCH 2317/100000 	 train loss 391.93731689453125

 EPOCH 2318/100000 	 train loss 476.0872802734375

 EPOCH 2319/100000 	 train loss 396.97418212890625

 EPOCH 2320/100000 	 train loss 235.16043090820312

 EPOCH 2321/100000 	 train loss 403.5345764160156

 EPOCH 2322/100000 	 train loss 313.8046569824219

 EPOCH 2323/100000 	 train loss 430.6938781738281

 EPOCH 2324/100000 	 train loss 446.78253173828125

 EPOCH 2325/100000 	 train loss 415.74798583984375

 EPOCH 2326/100000 	 train loss 401.6485290527344

 EPOCH 2327/100000 	 train loss 379.081298828125

 EPOCH 2328/100000 	 tr


 EPOCH 2471/100000 	 train loss 622.5709838867188

 EPOCH 2472/100000 	 train loss 465.69610595703125

 EPOCH 2473/100000 	 train loss 483.45123291015625

 EPOCH 2474/100000 	 train loss 515.1436767578125

 EPOCH 2475/100000 	 train loss 416.7364807128906

 EPOCH 2476/100000 	 train loss 273.9610290527344

 EPOCH 2477/100000 	 train loss 334.32275390625

 EPOCH 2478/100000 	 train loss 408.662353515625

 EPOCH 2479/100000 	 train loss 813.9324951171875

 EPOCH 2480/100000 	 train loss 449.1861267089844

 EPOCH 2481/100000 	 train loss 526.4420776367188

 EPOCH 2482/100000 	 train loss 550.1732177734375

 EPOCH 2483/100000 	 train loss 493.678955078125

 EPOCH 2484/100000 	 train loss 524.3250732421875

 EPOCH 2485/100000 	 train loss 560.2427978515625

 EPOCH 2486/100000 	 train loss 625.52734375

 EPOCH 2487/100000 	 train loss 414.9507751464844

 EPOCH 2488/100000 	 train loss 464.874267578125

 EPOCH 2489/100000 	 train loss 368.2599182128906

 EPOCH 2490/100000 	 train loss 568.65


 EPOCH 2632/100000 	 train loss 258.2296447753906

 EPOCH 2633/100000 	 train loss 427.9398193359375

 EPOCH 2634/100000 	 train loss 421.6357727050781

 EPOCH 2635/100000 	 train loss 335.8719787597656

 EPOCH 2636/100000 	 train loss 425.9888000488281

 EPOCH 2637/100000 	 train loss 360.30419921875

 EPOCH 2638/100000 	 train loss 351.1875

 EPOCH 2639/100000 	 train loss 361.8078308105469

 EPOCH 2640/100000 	 train loss 391.9091491699219

 EPOCH 2641/100000 	 train loss 390.85968017578125

 EPOCH 2642/100000 	 train loss 387.8278503417969

 EPOCH 2643/100000 	 train loss 399.24951171875

 EPOCH 2644/100000 	 train loss 389.214599609375

 EPOCH 2645/100000 	 train loss 300.2754211425781

 EPOCH 2646/100000 	 train loss 402.32354736328125

 EPOCH 2647/100000 	 train loss 373.6698303222656

 EPOCH 2648/100000 	 train loss 591.1339111328125

 EPOCH 2649/100000 	 train loss 588.7412719726562

 EPOCH 2650/100000 	 train loss 434.50506591796875

 EPOCH 2651/100000 	 train loss 427.07476


 EPOCH 2795/100000 	 train loss 698.359619140625

 EPOCH 2796/100000 	 train loss 438.713134765625

 EPOCH 2797/100000 	 train loss 534.8048095703125

 EPOCH 2798/100000 	 train loss 572.8543090820312

 EPOCH 2799/100000 	 train loss 456.42529296875

 EPOCH 2800/100000 	 train loss 402.3944091796875

 EPOCH 2801/100000 	 train loss 434.67144775390625

 EPOCH 2802/100000 	 train loss 439.78558349609375

 EPOCH 2803/100000 	 train loss 478.49688720703125

 EPOCH 2804/100000 	 train loss 635.9566040039062

 EPOCH 2805/100000 	 train loss 726.4514770507812

 EPOCH 2806/100000 	 train loss 466.60675048828125

 EPOCH 2807/100000 	 train loss 521.1470947265625

 EPOCH 2808/100000 	 train loss 783.0949096679688

 EPOCH 2809/100000 	 train loss 477.1133117675781

 EPOCH 2810/100000 	 train loss 340.52325439453125

 EPOCH 2811/100000 	 train loss 419.4107360839844

 EPOCH 2812/100000 	 train loss 446.24615478515625

 EPOCH 2813/100000 	 train loss 475.4090270996094

 EPOCH 2814/100000 	 train l


 EPOCH 2956/100000 	 train loss 375.423095703125

 EPOCH 2957/100000 	 train loss 388.9021301269531

 EPOCH 2958/100000 	 train loss 422.22064208984375

 EPOCH 2959/100000 	 train loss 361.69427490234375

 EPOCH 2960/100000 	 train loss 334.1167907714844

 EPOCH 2961/100000 	 train loss 402.48809814453125

 EPOCH 2962/100000 	 train loss 326.4889221191406

 EPOCH 2963/100000 	 train loss 386.16461181640625

 EPOCH 2964/100000 	 train loss 274.6260986328125

 EPOCH 2965/100000 	 train loss 409.1802978515625

 EPOCH 2966/100000 	 train loss 422.06903076171875

 EPOCH 2967/100000 	 train loss 385.62255859375

 EPOCH 2968/100000 	 train loss 315.3655090332031

 EPOCH 2969/100000 	 train loss 318.8948059082031

 EPOCH 2970/100000 	 train loss 339.623291015625

 EPOCH 2971/100000 	 train loss 347.4435119628906

 EPOCH 2972/100000 	 train loss 460.405029296875

 EPOCH 2973/100000 	 train loss 548.6712646484375

 EPOCH 2974/100000 	 train loss 282.31463623046875

 EPOCH 2975/100000 	 train lo


 EPOCH 3119/100000 	 train loss 559.9244995117188

 EPOCH 3120/100000 	 train loss 730.243408203125

 EPOCH 3121/100000 	 train loss 676.806884765625

 EPOCH 3122/100000 	 train loss 312.36212158203125

 EPOCH 3123/100000 	 train loss 630.412109375

 EPOCH 3124/100000 	 train loss 356.6800537109375

 EPOCH 3125/100000 	 train loss 523.0213012695312

 EPOCH 3126/100000 	 train loss 809.9832763671875

 EPOCH 3127/100000 	 train loss 729.6817626953125

 EPOCH 3128/100000 	 train loss 411.6448974609375

 EPOCH 3129/100000 	 train loss 800.7451782226562

 EPOCH 3130/100000 	 train loss 683.1256713867188

 EPOCH 3131/100000 	 train loss 576.5159301757812

 EPOCH 3132/100000 	 train loss 479.6126708984375

 EPOCH 3133/100000 	 train loss 455.28790283203125

 EPOCH 3134/100000 	 train loss 489.3337097167969

 EPOCH 3135/100000 	 train loss 376.0902099609375

 EPOCH 3136/100000 	 train loss 433.88812255859375

 EPOCH 3137/100000 	 train loss 1535.2786865234375

 EPOCH 3138/100000 	 train loss 


 EPOCH 3281/100000 	 train loss 402.1661071777344

 EPOCH 3282/100000 	 train loss 350.33880615234375

 EPOCH 3283/100000 	 train loss 329.80975341796875

 EPOCH 3284/100000 	 train loss 455.96405029296875

 EPOCH 3285/100000 	 train loss 375.8516845703125

 EPOCH 3286/100000 	 train loss 962.2531127929688

 EPOCH 3287/100000 	 train loss 383.3797912597656

 EPOCH 3288/100000 	 train loss 565.6050415039062

 EPOCH 3289/100000 	 train loss 469.15521240234375

 EPOCH 3290/100000 	 train loss 571.968017578125

 EPOCH 3291/100000 	 train loss 488.50640869140625

 EPOCH 3292/100000 	 train loss 567.2404174804688

 EPOCH 3293/100000 	 train loss 362.7524719238281

 EPOCH 3294/100000 	 train loss 500.1913757324219

 EPOCH 3295/100000 	 train loss 345.4017333984375

 EPOCH 3296/100000 	 train loss 440.7544250488281

 EPOCH 3297/100000 	 train loss 367.028564453125

 EPOCH 3298/100000 	 train loss 455.8704833984375

 EPOCH 3299/100000 	 train loss 375.46234130859375

 EPOCH 3300/100000 	 train


 EPOCH 3443/100000 	 train loss 444.3337707519531

 EPOCH 3444/100000 	 train loss 523.83740234375

 EPOCH 3445/100000 	 train loss 503.2442626953125

 EPOCH 3446/100000 	 train loss 329.1167907714844

 EPOCH 3447/100000 	 train loss 524.8079223632812

 EPOCH 3448/100000 	 train loss 418.06103515625

 EPOCH 3449/100000 	 train loss 492.7156982421875

 EPOCH 3450/100000 	 train loss 501.1773681640625

 EPOCH 3451/100000 	 train loss 514.151123046875

 EPOCH 3452/100000 	 train loss 414.587890625

 EPOCH 3453/100000 	 train loss 866.2990112304688

 EPOCH 3454/100000 	 train loss 563.8518676757812

 EPOCH 3455/100000 	 train loss 418.3790588378906

 EPOCH 3456/100000 	 train loss 494.6688537597656

 EPOCH 3457/100000 	 train loss 394.0792236328125

 EPOCH 3458/100000 	 train loss 529.6693725585938

 EPOCH 3459/100000 	 train loss 379.669921875

 EPOCH 3460/100000 	 train loss 622.8380737304688

 EPOCH 3461/100000 	 train loss 488.11431884765625

 EPOCH 3462/100000 	 train loss 466.993164


 EPOCH 3606/100000 	 train loss 533.1477661132812

 EPOCH 3607/100000 	 train loss 581.795166015625

 EPOCH 3608/100000 	 train loss 278.5618896484375

 EPOCH 3609/100000 	 train loss 334.72357177734375

 EPOCH 3610/100000 	 train loss 521.6918334960938

 EPOCH 3611/100000 	 train loss 458.4723815917969

 EPOCH 3612/100000 	 train loss 406.3973388671875

 EPOCH 3613/100000 	 train loss 338.6976623535156

 EPOCH 3614/100000 	 train loss 310.4305419921875

 EPOCH 3615/100000 	 train loss 381.52032470703125

 EPOCH 3616/100000 	 train loss 507.7805480957031

 EPOCH 3617/100000 	 train loss 360.2529602050781

 EPOCH 3618/100000 	 train loss 439.37322998046875

 EPOCH 3619/100000 	 train loss 337.4724426269531

 EPOCH 3620/100000 	 train loss 336.3214111328125

 EPOCH 3621/100000 	 train loss 671.417236328125

 EPOCH 3622/100000 	 train loss 340.4980163574219

 EPOCH 3623/100000 	 train loss 380.1986999511719

 EPOCH 3624/100000 	 train loss 413.6297912597656

 EPOCH 3625/100000 	 train lo


 EPOCH 3769/100000 	 train loss 449.56549072265625

 EPOCH 3770/100000 	 train loss 594.296142578125

 EPOCH 3771/100000 	 train loss 617.17822265625

 EPOCH 3772/100000 	 train loss 979.3893432617188

 EPOCH 3773/100000 	 train loss 547.5576171875

 EPOCH 3774/100000 	 train loss 609.7576904296875

 EPOCH 3775/100000 	 train loss 600.368408203125

 EPOCH 3776/100000 	 train loss 501.4836730957031

 EPOCH 3777/100000 	 train loss 852.98388671875

 EPOCH 3778/100000 	 train loss 516.59765625

 EPOCH 3779/100000 	 train loss 598.0301513671875

 EPOCH 3780/100000 	 train loss 495.67230224609375

 EPOCH 3781/100000 	 train loss 566.4197998046875

 EPOCH 3782/100000 	 train loss 450.817138671875

 EPOCH 3783/100000 	 train loss 391.5381164550781

 EPOCH 3784/100000 	 train loss 602.26220703125

 EPOCH 3785/100000 	 train loss 1071.6077880859375

 EPOCH 3786/100000 	 train loss 609.8865966796875

 EPOCH 3787/100000 	 train loss 349.9799499511719

 EPOCH 3788/100000 	 train loss 1071.3364257


 EPOCH 3931/100000 	 train loss 331.555419921875

 EPOCH 3932/100000 	 train loss 477.1939697265625

 EPOCH 3933/100000 	 train loss 340.26092529296875

 EPOCH 3934/100000 	 train loss 234.693359375

 EPOCH 3935/100000 	 train loss 445.5726013183594

 EPOCH 3936/100000 	 train loss 417.92919921875

 EPOCH 3937/100000 	 train loss 277.05731201171875

 EPOCH 3938/100000 	 train loss 292.1360778808594

 EPOCH 3939/100000 	 train loss 332.225830078125

 EPOCH 3940/100000 	 train loss 353.4809875488281

 EPOCH 3941/100000 	 train loss 418.27972412109375

 EPOCH 3942/100000 	 train loss 295.73638916015625

 EPOCH 3943/100000 	 train loss 343.5893249511719

 EPOCH 3944/100000 	 train loss 418.3992919921875

 EPOCH 3945/100000 	 train loss 355.6845703125

 EPOCH 3946/100000 	 train loss 334.0557556152344

 EPOCH 3947/100000 	 train loss 362.88092041015625

 EPOCH 3948/100000 	 train loss 314.273193359375

 EPOCH 3949/100000 	 train loss 362.369384765625

 EPOCH 3950/100000 	 train loss 378.50


 EPOCH 4093/100000 	 train loss 412.37445068359375

 EPOCH 4094/100000 	 train loss 376.7083435058594

 EPOCH 4095/100000 	 train loss 405.9243469238281

 EPOCH 4096/100000 	 train loss 337.27923583984375

 EPOCH 4097/100000 	 train loss 326.16021728515625

 EPOCH 4098/100000 	 train loss 385.6965637207031

 EPOCH 4099/100000 	 train loss 368.3167724609375

 EPOCH 4100/100000 	 train loss 337.08154296875

 EPOCH 4101/100000 	 train loss 372.5851745605469

 EPOCH 4102/100000 	 train loss 494.1510009765625

 EPOCH 4103/100000 	 train loss 348.4629211425781

 EPOCH 4104/100000 	 train loss 418.87646484375

 EPOCH 4105/100000 	 train loss 343.9559020996094

 EPOCH 4106/100000 	 train loss 349.69952392578125

 EPOCH 4107/100000 	 train loss 293.5696105957031

 EPOCH 4108/100000 	 train loss 278.543701171875

 EPOCH 4109/100000 	 train loss 538.7564697265625

 EPOCH 4110/100000 	 train loss 352.5230712890625

 EPOCH 4111/100000 	 train loss 410.08746337890625

 EPOCH 4112/100000 	 train los


 EPOCH 4254/100000 	 train loss 317.9032287597656

 EPOCH 4255/100000 	 train loss 354.294189453125

 EPOCH 4256/100000 	 train loss 440.91864013671875

 EPOCH 4257/100000 	 train loss 329.2049255371094

 EPOCH 4258/100000 	 train loss 419.46148681640625

 EPOCH 4259/100000 	 train loss 343.89990234375

 EPOCH 4260/100000 	 train loss 289.1492614746094

 EPOCH 4261/100000 	 train loss 321.39984130859375

 EPOCH 4262/100000 	 train loss 399.98614501953125

 EPOCH 4263/100000 	 train loss 267.950439453125

 EPOCH 4264/100000 	 train loss 365.11553955078125

 EPOCH 4265/100000 	 train loss 453.2189025878906

 EPOCH 4266/100000 	 train loss 380.897216796875

 EPOCH 4267/100000 	 train loss 371.8917541503906

 EPOCH 4268/100000 	 train loss 362.8155822753906

 EPOCH 4269/100000 	 train loss 416.0416564941406

 EPOCH 4270/100000 	 train loss 374.61419677734375

 EPOCH 4271/100000 	 train loss 389.6842041015625

 EPOCH 4272/100000 	 train loss 467.3454284667969

 EPOCH 4273/100000 	 train lo


 EPOCH 4416/100000 	 train loss 301.8504638671875

 EPOCH 4417/100000 	 train loss 324.2147216796875

 EPOCH 4418/100000 	 train loss 326.4446105957031

 EPOCH 4419/100000 	 train loss 378.5169372558594

 EPOCH 4420/100000 	 train loss 291.2576904296875

 EPOCH 4421/100000 	 train loss 370.8482666015625

 EPOCH 4422/100000 	 train loss 304.5286865234375

 EPOCH 4423/100000 	 train loss 347.51318359375

 EPOCH 4424/100000 	 train loss 378.71331787109375

 EPOCH 4425/100000 	 train loss 410.9040222167969

 EPOCH 4426/100000 	 train loss 798.3575439453125

 EPOCH 4427/100000 	 train loss 479.4822998046875

 EPOCH 4428/100000 	 train loss 408.77813720703125

 EPOCH 4429/100000 	 train loss 420.5904846191406

 EPOCH 4430/100000 	 train loss 373.214111328125

 EPOCH 4431/100000 	 train loss 311.00689697265625


In [12]:
import torchvision
from torchvision import transforms
import random

def show_image(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

encoder.eval()
decoder.eval()

with torch.no_grad():
    # calculate mean and std of latent code, generated takining in test images as inputs 
    images = torch.from_numpy(np.round(np.random.rand(100, 1, im_size[0], im_size[1]))).float()
    labels = np.zeros(15, dtype = float)
    
    # Setting labels for each attribute randomly based on number of classes for each
    randnum = random.randint(0,3)
    labels[randnum] = 1.0
    
    randnum = random.randint(0,3))
    labels[randnum + 4] = 1.0
    
    randnum = random.randint(0,3))
    labels[randnum + 8] = 1.0
    
    randnum = random.randint(0,2))
    labels[randnum + 12] = 1.0


    images = images.to(device)
    labels = torch.from_numpy(labels)
    latent = encoder(images, labels.to(device))
    latent = latent.cpu()

    #mean = latent.mean(dim=0)
    #print(mean)
    #std = (latent - mean).pow(2).mean(dim=0).sqrt()
    #print(std)

    # sample latent vectors from the normal distribution
    latent = torch.from_numpy(np.round(np.random.rand(128, d))).float()
    latent = torch.cat((latent, labels), dim=1)
    # reconstruct images from the random latent vectors
    latent = latent.to(device)
    img_recon = decoder(latent)
    img_recon = img_recon.cpu()

    fig, ax = plt.subplots(figsize=(40, 16.5))
    show_image(torchvision.utils.make_grid(img_recon[:100],10,5))
    plt.show()


TypeError: forward() missing 1 required positional argument: 'attributes'

In [None]:
fig, ax = plt.subplots(10, 2, figsize=(10, 40))
fig.set_tight_layout(True)

c = 0
for data in testing_generator: # with "_" we just ignore the labels (the second element of the dataloader tuple)
    # Move tensor to the proper device
    if(c>=10):
        break
    image_batch = data["image"]


    #print("Random data shape", np.shape(random_data), "; data[labels] ", np.shape(data["labels"]))
    latent = data["labels"].float()
    latent = latent.to(device)
   
    
    img_recon = decoder(latent)

    img_recon = img_recon.detach().cpu().numpy()
    
    ax[c, 0].imshow(image_batch[0, 0, :, :])
    ax[c, 1].imshow(img_recon[0, 0, :, :])
    
    c += 1
    

plt.tight_layout()
    
    

In [None]:
print(torch.__version__)