In [53]:
import os
from PIL import Image
import re
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from models import *
import numpy as np
import time
import cv2
from PIL import Image
import torch.utils.data as Data
import torchvision
import torchvision.transforms as transforms

In [54]:
imgDir = '/export/hdd/scratch/hchen799/INR/reconimage_new'
test_imgDir = './cifar_10_images/test_cifar10'
batch_size = 128
batch_size_test  = 100
train_nSamples = 50000
test_nSamples = 10000
init_width = 32
init_height = 32

In [55]:
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

In [56]:
def load_image_path(imgDir):

    all_training_files=os.walk(imgDir)
    train_files=[]
    train_imageNames=[]
    train_nSamples=0
    for path,direction,filelist in all_training_files:
        files = [file for file in filelist if os.path.isfile(os.path.join(path, file))]
        imageNames = [file.split('.')[0] for file in files if is_image_file(file)]
        files = [os.path.join(path, file) for file in files if is_image_file(file)]
        train_files.append(files)
        train_imageNames.append(imageNames)
        train_nSamples=train_nSamples+len(files)
    train_files=sum(train_files,[])
    train_imageNames=sum(train_imageNames,[])
    #print(train_imageNames)
    train_imageNames.sort(key = lambda i:int(re.match(r'(\d+)',i).group()))
    #train_imageNames.sort(key = lambda x: int(x[:-4]))
    train_image_path = []
    for i in range (len(train_imageNames)):
        string = imgDir + '/' + train_imageNames[i] + '.png'
        train_image_path.append(string)
    return train_image_path

def load_image_path_test(imgDir):

    all_training_files=os.walk(imgDir)
    train_files=[]
    train_imageNames=[]
    train_nSamples=0
    for path,direction,filelist in all_training_files:
        files = [file for file in filelist if os.path.isfile(os.path.join(path, file))]
        imageNames = [file.split('.')[0] for file in files if is_image_file(file)]
        files = [os.path.join(path, file) for file in files if is_image_file(file)]
        train_files.append(files)
        train_imageNames.append(imageNames)
        train_nSamples=train_nSamples+len(files)
    train_files=sum(train_files,[])
    train_imageNames=sum(train_imageNames,[])
    #print(train_imageNames)
    train_imageNames.sort(key = lambda i:int(re.match(r'(\d+)',i).group()))
    #train_imageNames.sort(key = lambda x: int(x[:-4]))
    train_image_path = []
    for i in range (len(train_imageNames)):
        string = imgDir + '/' + train_imageNames[i] + '.jpg'
        train_image_path.append(string)
    return train_image_path

In [57]:
train_image_path = load_image_path(imgDir)
print(train_image_path[23])

/export/hdd/scratch/hchen799/INR/reconimage_new/23.png


In [58]:
test_image_path = load_image_path_test(test_imgDir)
train_label = np.load("cifar_10_labels.npy")
test_label = np.load("cifar_10_labels_test.npy")

In [59]:
print(train_label)

[6. 9. 9. ... 9. 1. 1.]


In [60]:
class listDataset(Dataset):
    def __init__(self, files_root, target, nsamples,shape=None, shuffle=True, transform=None, target_transform=None, train=False, seen=0, batch_size=32, num_workers=0):
      
      self.image_root=files_root
      self.target=target
      self.nSamples=nsamples
      self.transform = transform
      self.target_transform = target_transform
      self.train = train
      self.shape = shape
      self.seen = seen
      self.batch_size = batch_size
      self.num_workers = num_workers
       
    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        imgpath = self.image_root[index]
        img = Image.open(imgpath).convert('RGB')
     #print(img)
        if self.shape is not None:
            img = img.resize(self.shape)
        if self.transform is not None:
            img = self.transform(img)
        label=self.target[index]
        #print(label.type)
        label = torch.from_numpy(np.array(label, dtype = np.int64))
     
        return (img, label)

In [61]:
lr = 0.01
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

In [62]:
print('==> Preparing data..')
transform_train = transforms.Compose([
    #transforms.ToPILImage(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

==> Preparing data..


In [63]:
train_loader = torch.utils.data.DataLoader(
        listDataset(train_image_path, train_label, train_nSamples, shape=(init_width, init_height),
                       shuffle=False,
                       transform=transform_train, 
                       train=True, 
                       seen=0,
                       batch_size=batch_size,
                       num_workers=0),
        batch_size=batch_size, shuffle=True, num_workers=8)

In [64]:
test_loader = torch.utils.data.DataLoader(
        listDataset(test_image_path, test_label, test_nSamples, shape=(init_width, init_height),
                       shuffle=False,
                       transform=transform_test, 
                       train=False, 
                       seen=0,
                       batch_size=batch_size,
                       num_workers=0),
        batch_size=batch_size_test, shuffle=False, num_workers=8)

In [65]:
net = ResNet18()
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
    
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [66]:
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    iter = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        iter = iter + 1
        #progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     #% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    print('Loss: %.8f | Acc: %.8f%% (%d/%d)'% (train_loss/(iter * 128), 100.*correct/total, correct, total))
    
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    iter = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            iter = iter + 1
            #progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         #% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    print('Loss: %.8f | Acc: %.8f%% (%d/%d)'% (test_loss/(iter * 100), 100.*correct/total, correct, total))
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt_recon1.pth')
        best_acc = acc
        
for epoch in range(start_epoch, start_epoch+200):
    #lr = adjust_learning_rate(epoch)
    start_time = time.time()
    train(epoch)
    test(epoch)
    end_time = time.time()
    print("the time for one epoch is:", end_time - start_time)
    scheduler.step()



Epoch: 0
Loss: 0.01217626 | Acc: 42.65200000% (21326/50000)
Loss: 0.01275571 | Acc: 55.58000000% (5558/10000)
Saving..
the time for one epoch is: 17.547449111938477

Epoch: 1
Loss: 0.00866993 | Acc: 60.70000000% (30350/50000)
Loss: 0.00885093 | Acc: 68.70000000% (6870/10000)
Saving..
the time for one epoch is: 17.458316802978516

Epoch: 2
Loss: 0.00723423 | Acc: 67.54800000% (33774/50000)
Loss: 0.00906266 | Acc: 69.43000000% (6943/10000)
Saving..
the time for one epoch is: 17.3698570728302

Epoch: 3
Loss: 0.00647088 | Acc: 71.17600000% (35588/50000)
Loss: 0.00712942 | Acc: 75.43000000% (7543/10000)
Saving..
the time for one epoch is: 17.1457200050354

Epoch: 4
Loss: 0.00596952 | Acc: 73.30800000% (36654/50000)
Loss: 0.00682529 | Acc: 76.40000000% (7640/10000)
Saving..
the time for one epoch is: 17.375280141830444

Epoch: 5
Loss: 0.00553602 | Acc: 75.29000000% (37645/50000)
Loss: 0.00860190 | Acc: 73.01000000% (7301/10000)
the time for one epoch is: 17.149155616760254

Epoch: 6
Loss: 0

In [269]:
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    iter = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        iter = iter + 1
        #progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     #% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    print('Loss: %.8f | Acc: %.8f%% (%d/%d)'% (train_loss/(iter * 128), 100.*correct/total, correct, total))
    
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    iter = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            iter = iter + 1
            #progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         #% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    print('Loss: %.8f | Acc: %.8f%% (%d/%d)'% (test_loss/(iter * 100), 100.*correct/total, correct, total))
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt1.pth')
        best_acc = acc
        
for epoch in range(start_epoch, start_epoch+200):
    #lr = adjust_learning_rate(epoch)
    start_time = time.
    train(epoch)
    test(epoch)
    scheduler.step()



Epoch: 0
Loss: 0.01152857 | Acc: 46.45800000% (23229/50000)
Loss: 0.01434262 | Acc: 55.00000000% (5500/10000)
Saving..

Epoch: 1
Loss: 0.00766150 | Acc: 65.20400000% (32602/50000)
Loss: 0.00977555 | Acc: 68.67000000% (6867/10000)
Saving..

Epoch: 2
Loss: 0.00596879 | Acc: 73.44600000% (36723/50000)
Loss: 0.00718585 | Acc: 75.39000000% (7539/10000)
Saving..

Epoch: 3
Loss: 0.00498961 | Acc: 77.97800000% (38989/50000)
Loss: 0.00678254 | Acc: 78.19000000% (7819/10000)
Saving..

Epoch: 4
Loss: 0.00442079 | Acc: 80.37800000% (40189/50000)
Loss: 0.00579626 | Acc: 80.19000000% (8019/10000)
Saving..

Epoch: 5
Loss: 0.00400438 | Acc: 82.30600000% (41153/50000)
Loss: 0.00574454 | Acc: 81.09000000% (8109/10000)
Saving..

Epoch: 6
Loss: 0.00364271 | Acc: 83.89400000% (41947/50000)
Loss: 0.00521002 | Acc: 83.03000000% (8303/10000)
Saving..

Epoch: 7
Loss: 0.00334364 | Acc: 85.04600000% (42523/50000)
Loss: 0.00461823 | Acc: 84.56000000% (8456/10000)
Saving..

Epoch: 8
Loss: 0.00314129 | Acc: 86.144