<a href="https://colab.research.google.com/github/takhtardeshirsoheib/soheib/blob/master/CT_Seg.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:

class Config:
    
    def __init__(self):
        #network configure
        self.InputCh=3
        self.ScaleRatio = 2
        self.ConvSize = 3
        self.pad = 1#(self.ConvSize - 1) / 2 
        self.MaxLv = 5
        self.ChNum = [self.InputCh,64]
        for i in range(self.MaxLv-1):
            self.ChNum.append(self.ChNum[-1]*2)
        #data configure
        self.pascal = "/content/drive/My Drive/CT/CT1/"
        self.bsds = "../BSR/BSDS500/data/images/"
        #self.imagelist = "ImageSets/Segmentation/train.txt"
        self.BatchSize = 6
        self.Shuffle = True
        self.LoadThread = 4
        self.inputsize = [224,224]
        #partition configure
        self.K = 64
        #training configure
        self.init_lr = 0.05
        self.lr_decay = 0.1
        self.lr_decay_iter = 1000
        self.max_iter = 50000
        self.cuda_dev = 0 
        self.cuda_dev_list = "0,1"
        self.check_iter = 1000
        #Ncuts Loss configure
        self.radius = 4
        self.sigmaI = 10
        self.sigmaX = 4
        #testing configure
        self.model_tested = "./checkpoint_8_23_13_0_epoch_2000"
        #color library
        self.color_lib = []
        for r in range(0,256,128):
            for g in range(0,256,128):
                for b in range(0,256,128):
                    self.color_lib.append((r,g,b))


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [30]:
from PIL import Image
import PIL
import torch
import torch.utils.data as Data
import os
import glob
import numpy as np
import pdb
import math
import cupy as cp
PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True
config = Config()

class DataLoader():
    #initialization
    #datapath : the data folder of bsds500
    #mode : train/test/val
    def __init__(self, datapath,mode):
        #image container
        self.path = "/content/drive/My Drive/CT/CT1/"
        self.Imgpath = os.listdir(self.path)
        self.raw_data = []
        self.mode = mode
        print(self.Imgpath)
        #navigate to the image directory
        #images_path = os.path.join(datapath,'images')
        train_image_path = os.path.join(datapath,mode)
        file_list = []
        if(mode != "train"):
            train_image_regex = os.path.join(train_image_path, '*.jpg')
            file_list = glob.glob(train_image_regex)
        #find all the images
        else:
            for i in range (len(self.Imgpath)):
              file_list.append(self.path + self.Imgpath[i])    
        #load the images
        for file_name in file_list:
            image = Image.open(file_name)
            if image.mode != "RGB":
              image = image.convert("RGB")
            self.raw_data.append(np.array(image.resize((config.inputsize[0],config.inputsize[1]),Image.BILINEAR)))
        #resize and align
        self.scale()
        #normalize
        self.transfer()
        
        #calculate weights by 2
        if(mode == "train"):
            self.dataset = self.get_dataset(self.raw_data, self.raw_data.shape,75)
        else:
            self.dataset = self.get_dataset(self.raw_data, self.raw_data.shape,75)
    
    def scale(self):
        for i in range(len(self.raw_data)):
            image = self.raw_data[i]
            self.raw_data[i] = np.stack((image[:,:,0],image[:,:,1],image[:,:,2]),axis = 0)
        self.raw_data = np.stack(self.raw_data,axis = 0)

    def transfer(self):
        #just for RGB 8-bit color
        self.raw_data = self.raw_data.astype(np.float)
        #for i in range(self.raw_data.shape[0]):
        #    Image.fromarray(self.raw_data[i].swapaxes(0,-1).astype(np.uint8)).save("./reconstruction/input_"+str(i)+".jpg")

    def torch_loader(self):
        return Data.DataLoader(
                                self.dataset,
                                batch_size = config.BatchSize,
                                shuffle = config.Shuffle,
                                num_workers = config.LoadThread,
                                pin_memory = True,
                            )

    def cal_weight(self,raw_data,shape):
        #According to the weight formula, when Euclidean distance < r,the weight is 0, so reduce the dissim matrix size to radius-1 to save time and space.
        print("calculating weights.")

        dissim = cp.zeros((shape[0],shape[1],shape[2],shape[3],(config.radius-1)*2+1,(config.radius-1)*2+1))
        data = cp.asarray(raw_data)
        padded_data = cp.pad(data,((0,0),(0,0),(config.radius-1,config.radius-1),(config.radius-1,config.radius-1)),'constant')
        for m in range(2*(config.radius-1)+1):
            for n in range(2*(config.radius-1)+1):
                dissim[:,:,:,:,m,n] = data-padded_data[:,:,m:shape[2]+m,n:shape[3]+n]
        #for i in range(dissim.shape[0]):
        #dissim = -cp.power(dissim,2).sum(1,keepdims = True)/config.sigmaI/config.sigmaI
        temp_dissim = cp.exp(-cp.power(dissim,2).sum(1,keepdims = True)/config.sigmaI**2)
        dist = cp.zeros((2*(config.radius-1)+1,2*(config.radius-1)+1))
        for m in range(1-config.radius,config.radius):
            for n in range(1-config.radius,config.radius):
                if m**2+n**2<config.radius**2:
                    dist[m+config.radius-1,n+config.radius-1] = cp.exp(-(m**2+n**2)/config.sigmaX**2)
        #for m in range(0,config.radius-1):
        #    temp_dissim[:,:,m,:,0:config.radius-1-m,:]=0.0
        #    temp_dissim[:,:,-1-m,:,m-config.radius+1:-1,:]=0.0
        #    temp_dissim[:,:,:,m,:,0:config.radius-1-m]=0.0
        #    temp_dissim[:,:,:,-1-m,:,m-config.radius+1:-1]=0.0
        print("weight calculated.")
        res = cp.multiply(temp_dissim,dist)
        #for m in range(50,70):

        #    print(m)
        #    for n in range(50,70):
        #        print(dissim[5,0,m,n])
        #print(dist)
        return res

    def get_dataset(self,raw_data,shape,batch_size):
        dataset = []
        for batch_id in range(0,shape[0],batch_size):
            print(batch_id)
            batch = raw_data[batch_id:min(shape[0],batch_id+batch_size)]
            if(self.mode == "train"):
                tmp_weight = self.cal_weight(batch,batch.shape)
                weight = cp.asnumpy(tmp_weight)
                dataset.append(Data.TensorDataset(torch.from_numpy(batch/256).float(),torch.from_numpy(weight).float()))
                del tmp_weight
            else:
                dataset.append(Data.TensorDataset(torch.from_numpy(batch/256).float()))
        cp.get_default_memory_pool().free_all_blocks()
        return Data.ConcatDataset(dataset)


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import pdb
config = Config()

class WNet(torch.nn.Module):
    def __init__(self):
        super(WNet, self).__init__()
        self.feature1 = []
        self.feature2 = []
        bias = True
        #U-Net1
        #module1
        self.module = []
        self.maxpool1 = []
        self.uconv1 = []
        self.module.append(
            self.add_conv_stage(config.ChNum[0],config.ChNum[1],config.ConvSize,padding=config.pad,seperable=False)   
        )
        
        #module2-5
        for i in range(2,config.MaxLv+1):
            self.module.append(self.add_conv_stage(config.ChNum[i-1],config.ChNum[i],config.ConvSize,padding=config.pad))
            
        #module6-8
        for i in range(config.MaxLv-1,1,-1):
            self.module.append(self.add_conv_stage(2*config.ChNum[i],config.ChNum[i],config.ConvSize,padding=config.pad))
        #module9
        self.module.append(
            self.add_conv_stage(2*config.ChNum[1],config.ChNum[1],config.ConvSize,padding=config.pad,seperable=False)
        )
        #module1-4
        for i in range(config.MaxLv-1):
            self.maxpool1.append(nn.MaxPool2d(config.ScaleRatio))
        #module5-8
        for i in range(config.MaxLv,1,-1):
            self.uconv1.append(nn.ConvTranspose2d(config.ChNum[i],config.ChNum[i-1],config.ScaleRatio,config.ScaleRatio,bias = True))
        self.predconv = nn.Conv2d(config.ChNum[1],config.K,1,bias = bias)
        self.pad = nn.ConstantPad2d(config.radius-1,0)
        self.softmax = nn.Softmax2d()
        self.module = torch.nn.ModuleList(self.module)
        self.maxpool1 = torch.nn.ModuleList(self.maxpool1)
        self.uconv1 = torch.nn.ModuleList(self.uconv1)
        #self.loss = NcutsLoss()
    def add_conv_stage(self,dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=True, seperable=True):
        if seperable:
            return nn.Sequential(
                nn.Conv2d(dim_in,dim_out,1,bias = bias),
                nn.Conv2d(dim_out,dim_out,kernel_size,padding = padding,groups = dim_out,bias = bias),
                nn.ReLU(),
                nn.BatchNorm2d(dim_out),
                nn.Conv2d(dim_out,dim_out,1,bias = bias),
                nn.Conv2d(dim_out,dim_out,kernel_size,padding = padding,groups = dim_out,bias = bias),
                nn.ReLU(),
                nn.BatchNorm2d(dim_out),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(dim_in,dim_out,kernel_size,padding = padding,bias = bias),
                nn.ReLU(),
                nn.BatchNorm2d(dim_out),
                nn.Conv2d(dim_out,dim_out,kernel_size,padding = padding,bias = bias),
                nn.ReLU(),
                nn.BatchNorm2d(dim_out),
            )
    def forward(self,x):
        self.feature1 = [x]
        #U-Net1
        self.feature1.append(self.module[0](x))
        for i in range(1,config.MaxLv):
            tempf = self.maxpool1[i-1](self.feature1[-1])
            self.feature1.append(self.module[i](tempf))
        for i in range(config.MaxLv,2*config.MaxLv-2):
            tempf = self.uconv1[i-config.MaxLv](self.feature1[-1])
            tempf = torch.cat((self.feature1[2*config.MaxLv-i-1],tempf),dim=1)
            self.feature1.append(self.module[i](tempf))
        tempf = self.uconv1[-1](self.feature1[-1])
        tempf = torch.cat((self.feature1[1],tempf),dim=1)
        tempf = self.module[-1](tempf)
        tempf = self.predconv(tempf)
        self.feature2 = [self.softmax(tempf)]
        return [self.feature2[0],self.pad(self.feature2[0])]
        #self.feature2.append(self.loss(self.feature2[0],self.feature2[1],w,sw))
        #U-Net2
        
        '''tempf = self.conv2[0](self.feature2[-1])
        tempf = self.ReLU2[0](tempf)
        tempf = self.bn2[0](tempf)
        tempf = self.conv2[1](tempf)
        tempf = self.ReLU2[1](tempf)
        self.feature2.append(self.bn2[1](tempf))

        for i in range(1,config.MaxLv):
            tempf = self.maxpool2[i-1](self.feature2[-1])
            tempf = self.conv2[4*i-2](tempf)
            tempf = self.conv2[4*i-1](tempf)
            tempf = self.ReLU2[2*i](tempf)
            tempf = self.bn2[2*i](tempf)
            tempf = self.conv2[4*i](tempf)
            tempf = self.conv2[4*i+1](tempf)
            tempf = self.ReLU2[2*i+1](tempf)
            
            self.feature2.append(self.bn2[2*i+1](tempf))
        for i in range(config.MaxLv,2*config.MaxLv-2):
            tempf = self.uconv2[i-config.MaxLv](self.feature2[-1])
            tempf = torch.cat((self.feature2[2*config.MaxLv-i-1],tempf),dim=1)
            tempf = self.conv2[4*i-2](tempf)
            tempf = self.conv2[4*i-1](tempf)
            tempf = self.ReLU2[2*i](tempf)
            tempf = self.bn2[2*i](tempf)
            tempf = self.conv2[4*i](tempf)
            tempf = self.conv2[4*i+1](tempf)
            tempf = self.ReLU2[2*i+1](tempf)
            tempf = self.bn2[2*i+1](tempf)            
            self.feature2.append(tempf)
        tempf = self.uconv2[config.MaxLv-2](self.feature2[-1])
        tempf = torch.cat((self.feature2[1],tempf),dim=1)
        tempf = self.conv2[-2](tempf)
        tempf = self.ReLU2[4*config.MaxLv-4](tempf)
        tempf = self.bn2[4*config.MaxLv-4](tempf)
        tempf = self.conv2[-1](tempf)
        tempf = self.ReLU2[4*config.MaxLv-3](tempf)
        tempf = self.bn2[4*config.MaxLv-3](tempf)            
        self.feature2.append(tempf)
        tempf = self.reconsconv(self.feature2[-1])
        tempf = self.ReLU2[-1](tempf)
        self.feature2[-1] = self.bn2[-1](tempf)
        '''



config = Config()
def add_conv_stage(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=True, useBN=False):
  if useBN:
    return nn.Sequential(
      nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
      nn.BatchNorm2d(dim_out),
      nn.LeakyReLU(0.1),
      nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
      nn.BatchNorm2d(dim_out),
      nn.LeakyReLU(0.1)
    )
  else:
    return nn.Sequential(
      nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
      nn.ReLU(),
      nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
      nn.ReLU()
    )

def add_merge_stage(ch_coarse, ch_fine, in_coarse, in_fine, upsample):
  conv = nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False)
  torch.cat(conv, in_fine)

  return nn.Sequential(
    nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False)
  )
  upsample(in_coarse)

def upsample(ch_coarse, ch_fine):
  return nn.Sequential(
    nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False),
    nn.ReLU()
  )

class Net(nn.Module):
  def __init__(self, useBN=False):
    super(Net, self).__init__()

    self.conv1   = add_conv_stage(config.InputCh, 32, useBN=useBN)
    self.conv2   = add_conv_stage(32, 64, useBN=useBN)
    self.conv3   = add_conv_stage(64, 128, useBN=useBN)
    self.conv4   = add_conv_stage(128, 256, useBN=useBN)
    self.conv5   = add_conv_stage(256, 512, useBN=useBN)

    self.conv4m = add_conv_stage(512, 256, useBN=useBN)
    self.conv3m = add_conv_stage(256, 128, useBN=useBN)
    self.conv2m = add_conv_stage(128,  64, useBN=useBN)
    self.conv1m = add_conv_stage( 64,  32, useBN=useBN)

    self.conv0  = nn.Sequential(
        nn.Conv2d(32, config.K, 3, 1, 1),
        nn.Sigmoid(),
        nn.Softmax2d()
    )
    self.pad = nn.ConstantPad2d(config.radius-1,0)
    self.max_pool = nn.MaxPool2d(2)

    self.upsample54 = upsample(512, 256)
    self.upsample43 = upsample(256, 128)
    self.upsample32 = upsample(128,  64)
    self.upsample21 = upsample(64 ,  32)
    ## weight initialization
    for m in self.modules():
      if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        if m.bias is not None:
          m.bias.data.zero_()
    #self.Kconst = torch.tensor(config.K).float()
    #self.cropped_seg = torch.zeros(config.BatchSize,config.K,config.inputsize[0],config.inputsize[1],(config.radius-1)*2+1,(config.radius-1)*2+1)
    #self.loss = NCutsLoss()


  def forward(self, x):#, weight):
    #sw = weight.sum(-1).sum(-1)
    conv1_out = self.conv1(x)
    #return self.upsample21(conv1_out)
    conv2_out = self.conv2(self.max_pool(conv1_out))
    conv3_out = self.conv3(self.max_pool(conv2_out))
    conv4_out = self.conv4(self.max_pool(conv3_out))
    conv5_out = self.conv5(self.max_pool(conv4_out))

    conv5m_out = torch.cat((self.upsample54(conv5_out), conv4_out), 1)
    conv4m_out = self.conv4m(conv5m_out)

    conv4m_out_ = torch.cat((self.upsample43(conv4m_out), conv3_out), 1)
    conv3m_out = self.conv3m(conv4m_out_)

    conv3m_out_ = torch.cat((self.upsample32(conv3m_out), conv2_out), 1)
    conv2m_out = self.conv2m(conv3m_out_)

    conv2m_out_ = torch.cat((self.upsample21(conv2m_out), conv1_out), 1)
    conv1m_out = self.conv1m(conv2m_out_)

    conv0_out = self.conv0(conv1m_out)
    padded_seg = self.pad(conv0_out)
    '''for m in torch.arange((config.radius-1)*2+1,dtype=torch.long):
        for n in torch.arange((config.radius-1)*2+1,dtype=torch.long):
            self.cropped_seg[:,:,:,:,m,n]=padded_seg[:,:,m:m+conv0_out.size()[2],n:n+conv0_out.size()[3]].clone()
    multi1 = self.cropped_seg.mul(weight)
    multi2 = multi1.view(multi1.shape[0],multi1.shape[1],multi1.shape[2],multi1.shape[3],-1).sum(-1).mul(conv0_out)
    multi3 = sum_weight.mul(conv0_out)
    assocA = multi2.view(multi2.shape[0],multi2.shape[1],-1).sum(-1)
    assocV = multi3.view(multi3.shape[0],multi3.shape[1],-1).sum(-1)
    assoc = assocA.div(assocV).sum(-1)
    loss = self.Kconst - assoc'''
    #loss = self.loss(conv0_out, padded_seg, weight, sw)
    return [conv0_out,padded_seg]


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.autograd import Function
import time
import pdb
import subprocess
import numpy as np

config = Config()

class NCutsLoss(nn.Module):
    def __init__(self):
        super(NCutsLoss,self).__init__()
        self.gpu_list = []
        '''
        for i in range(torch.cuda.device_count()):
            self.gpu_list.append(torch.cuda.device(i))
# the ratio of the free space among all gpus
        self.gpu_room_list = []
        self.gpu_room_update()
        '''
    '''def gpu_room_update(self):
        self.gpu_room_list = []
        free_memory = get_gpu_memory_map()
        total_free = 0
        count_ratio = 0.0
        for _, value in free_memory.items():
            total_free+=value
        for dev in self.gpu_list:
            ratio = float(free_memory[dev])/total_free
            self.gpu_room_list.append(ratio)
            count_ratio += ratio
        if (count_ratio - 1 < 0):
            self.gpu_room_list[-1]+=1.0-count_ratio 
    '''    
            

    def forward(self, seg, padded_seg, weight,sum_weight):
        #too many values to unpack
        cropped_seg = []
        for m in torch.arange((config.radius-1)*2+1,dtype=torch.long):
            column = []
            for n in torch.arange((config.radius-1)*2+1,dtype=torch.long):
                column.append(padded_seg[:,:,m:m+seg.size()[2],n:n+seg.size()[3]].clone())
            cropped_seg.append(torch.stack(column,4))
        cropped_seg = torch.stack(cropped_seg,4)
        #for m in torch.arange(50,70,dtype=torch.long):

        #    print(m)
        #    for n in torch.arange(50,70,dtype= torch.long):
        #        print(weight[5,0,m,n])
        multi1 = cropped_seg.mul(weight)
        multi2 = multi1.sum(-1).sum(-1).mul(seg)
        multi3 = sum_weight.mul(seg)
        #print("=============================================================================")
        #for a in [0,1]:
        #    print(multi2[5,0,a*10+50:a*10+60,50:60])
        #    print(multi2[5,0,a*10+50:a*10+60,60:70])
        assocA = multi2.view(multi2.shape[0],multi2.shape[1],-1).sum(-1)
        assocV = multi3.view(multi3.shape[0],multi3.shape[1],-1).sum(-1)
        assoc = assocA.div(assocV).sum(-1)
        
        return torch.add(-assoc,config.K)
        
    '''def crop_seg(self,seg):
        cropped_seg = torch.zeros(seg.size()[0],seg.size()[1],seg.size()[2],seg.size()[3],(config.radius-1)*2+1,(config.radius-1)*2+1)
        padding_size = (config.radius,config.radius,config.radius,config.radius)
        padded_seg = torch.nn.functional.pad(seg,padding_size)
        for m in torch.arange((config.radius-1)*2+1,dtype=torch.long):
            for n in torch.arange((config.radius-1)*2+1,dtype=torch.long):
                cropped_seg[:,:,:,:,m,n].copy_(padded_seg[:,:,m:m+seg.size()[2],n:n+seg.size()[3]])
        return cropped_seg
    
def get_gpu_memory_map():
    """Get the current gpu usage.

    Returns
    -------
    usage: dict
        Keys are device ids as integers.
        Values are memory free as integers in MB.
    """
    result = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.free',
            '--format=csv,nounits,noheader'
        ], encoding='utf-8')
    # Convert lines into a dictionary
    gpu_memory = [int(x) for x in result.strip().split('\n')]
    gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
    return gpu_memory_map
'''        




In [32]:
import torch
import numpy as np
import time
import os

config = Config()
os.environ["CUDA_VISIBLE_DEVICES"]=config.cuda_dev_list
if __name__ == '__main__':
    dataset = DataLoader(config.pascal,"train")
    dataloader = dataset.torch_loader()
    #model = torch.nn.DataParallel(Net(True))
    model = torch.nn.DataParallel(WNet())
    model.cuda()
    #model.to(device)
    model.train()
    #optimizer
    
    optimizer = torch.optim.SGD(model.parameters(),lr = config.init_lr)
    #reconstr = torch.nn.MSELoss().cuda(config.cuda_dev)
    Ncuts = NCutsLoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.lr_decay_iter, gamma=config.lr_decay)
    
    for epoch in range(config.max_iter):
        print("Epoch: "+str(epoch+1))
        scheduler.step()
        Ave_Ncuts = 0.0
        #Ave_Rec = 0.0
        t_load = 0.0
        t_forward = 0.0
        t_loss = 0.0
        t_backward = 0.0
        t_adjust = 0.0
        t_reset = 0.0
        t_inloss = 0.0
        for step,[x,w] in enumerate(dataloader):
            #NCuts Loss
            #tick = time.time()
            x = x.cuda()
            w = w.cuda()
            #for m in torch.arange(50,70,dtype=torch.long):

            #    print(m)
            #    for n in torch.arange(50,70,dtype= torch.long):
            #        print(w[5,0,m,n])
            sw = w.sum(-1).sum(-1)
            #t_load += time.time()-tick
            #tick = time.time()
            optimizer.zero_grad()
            pred,pad_pred = model(x)
            #t_forward += time.time()-tick
            #pred.cuda()
            #tick = time.time()
            ncuts_loss = Ncuts(pred,pad_pred,w,sw)
            ncuts_loss = ncuts_loss.sum()/config.BatchSize 
            #t_loss += time.time()-tick
            #tick = time.time()
            Ave_Ncuts = (Ave_Ncuts * step + ncuts_loss.item())/(step+1)
            #t_reset += time.time()-tick
            #tick = time.time()
            ncuts_loss.backward()
            #t_backward += time.time()-tick
            #tick = time.time()
            optimizer.step()
            #t_adjust += time.time()-tick
            #Reconstruction Loss
            '''pred,rec_image = model(x)
            rec_loss = reconstr(rec_image,x)
            Ave_Rec = (Ave_Rec * step + rec_loss.item())/(step+1)
            optimizer.zero_grad()
            rec_loss.backward()
            optimizer.step()'''
        #t_total = t_load+t_reset+t_forward+t_loss+t_backward+t_adjust
        print("Ncuts loss: "+str(Ave_Ncuts))#+";total time: "+str(t_total)+";forward: "+str(t_forward/t_total)+";loss: "+str(t_loss/t_total)+";backward: "+str(t_backward/t_total)+";adjust: "+str(t_adjust/t_total)+";reset&load: "+str(t_reset/t_total)+"&"+str(t_load/t_total)+"loss: "+str(t_loss)+" / "+str(t_inloss))
        #print("Reconstruction loss: "+str(Ave_Rec))
        if (epoch+1)%500 == 0:
            localtime = time.localtime(time.time())
            checkname = './checkpoints'
            if not os.path.isdir(checkname):
                os.mkdir(checkname)
            checkname+='/checkpoint'
            for i in range(1,5):
                checkname+='_'+str(localtime[i])
            checkname += '_epoch_'+str(epoch+1)
            with open(checkname,'wb') as f:
                torch.save({
                    'epoch': epoch +1,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'Ncuts': Ave_Ncuts#,
                    #'recon': Ave_Rec
                    },f)
            print(checkname+' saved')
    

['6.png', '4.png', '2.png', '10.png', '5.png', '9.png', '3.png', '7.png', '11.png', '8.png', '0.png', '1.png', '15.png', '21.png', '24.png', '18.png', '23.png', '25.png', '14.png', '17.png', '26.png', '20.png', '13.png', '27.png', '19.png', '16.png', '22.png', '12.png', '43.png', '30.png', '29.png', '32.png', '35.png', '40.png', '44.png', '37.png', '38.png', '33.png', '28.png', '36.png', '42.png', '45.png', '31.png', '41.png', '39.png', '34.png', '53.png', '63.png', '59.png', '50.png', '62.png', '51.png', '54.png', '58.png', '46.png', '61.png', '49.png', '56.png', '48.png', '60.png', '55.png', '57.png', '52.png', '47.png', '64.png', '78.png', '76.png', '66.png', '67.png', '79.png', '74.png', '75.png', '80.png', '72.png', '73.png', '65.png', '77.png', '69.png', '68.png', '70.png', '71.png', '93.png', '94.png', '90.png', '92.png', '89.png', '82.png', '88.png', '85.png', '86.png', '84.png', '95.png', '87.png', '97.png', '81.png', '83.png', '96.png', '91.png', '111.png', '107.png', '98.png

OutOfMemoryError: ignored