In [None]:
import torch
import torchvision
import numpy as np
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import datetime
import PIL
import os
import time
import argparse
import matplotlib.pyplot as plt
from CycleGAN import CycleGAN
from ImageDataset import ImageDataset
from torch.utils.data import DataLoader
import argparse

## Optional

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
parser.add_argument('--dataroot', type=str, default='../datasets/monet2photo/', help='root directory of the dataset')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
parser.add_argument('--size', type=int, default=256, help='size of the data crop (squared assumed)')
parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
parser.add_argument('--cuda', action='store_true', help='use GPU computation')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
opt = parser.parse_args(args=[])
parser.add_argument('--batchSize', type=int, default=1, help='batch size')
parser.add_argument('--epochs', type=int, default=200, help='number of epochs')
parser.add_argument('--lambda', type=float, default=10, help='weighr for cycle consistency loss')

_StoreAction(option_strings=['--lambda'], dest='lambda', nargs=None, const=None, default=10, type=<class 'float'>, choices=None, help='weighr for cycle consistency loss', metavar=None)

### Hyperparameters

In [None]:
LR = 0.001
batch_size = 200
num_epochs = 10

### Helper Methods

In [None]:
def plot_graph(num_epochs,acc_list,loss_list):
    #usage : plot_graph(num_epochs,acc_list,loss_list)
    plt.ioff()
    fig = plt.figure()
    plt.subplot(2, 1, 1)
    plt.ylabel('Training loss')
    plt.plot(np.arange(num_epochs), loss_list, 'k-')
    plt.title('Training Loss and Training Accuracy')
    plt.xticks(np.arange(num_epochs, dtype=int))
    plt.grid(True)

    plt.subplot(2, 1, 2)
    plt.plot(np.arange(num_epochs), acc_list, 'b-')
    plt.ylabel('Training Accuracy')
    plt.xlabel('Epochs')
    plt.xticks(np.arange(num_epochs, dtype=int))
    plt.grid(True)
    plt.savefig("plot.png")
    plt.close(fig)

# Main

### Load DataSet

In [None]:
#256 x 256
transform_train = [transforms.RandomCrop(256, padding=4),transforms.RandomHorizontalFlip(p=2),
                   transforms.ToTensor(),
                   transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]
transform_test = [transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]
trainset = ImageDataset('../datasets/monet2photo/', transforms_=transform_train,mode='train')
testset = ImageDataset('../datasets/monet2photo/', transforms_=transform_train,mode='test')
train_loader = DataLoader(trainset,batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset,batch_size=batch_size, shuffle=True)
print(len(trainset))
print(len(testset))

### Declaration

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = CycleGAN('opt').to(device)
# scheduler = lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size, gamma=0.1)

### Training

In [None]:
acc_list = []
loss_list = []
for epoch in range(1,num_epochs+1):
    #model.train()
    start_time = time.time()
    running_loss = 0.0
    acc = 0.0
    print("epoch {}/{}".format(epoch,num_epochs))
    for batch_idx, data in enumerate(train_loader):
        #print(data)
        A = data['A'].to(device)
        B = data['B'].to(device)
        
        #optimizer.zero_grad()
        #outputs = model(images)
#         model.load(img_A, imgB)
#         model.optimize_parameters()
        #loss = criterion(outputs,labels)
        #_,preds = torch.max(outputs.data,1)
        #loss.backward()
        #optimizer.step()
        #running_loss+=loss.item()
        #acc+=torch.sum(preds == labels).item()
    end_time = time.time()
    print('Training Time: ',end_time-start_time ,'s, Training accurarcy: ',acc/len(trainset),', Training loss: ',running_loss/len(trainset))
    
    correct = 0
    with torch.no_grad():
        #model.eval()
        start_time = time.time()
        for batch_idx, (A,B) in enumerate(test_loader):
            A = A.to(device)
            B = B.to(device)
            #model.load(A,B)
            #model.optimize_parameters()
            #_,predicted = torch.max(outputs.data,1)
            #correct+=torch.sum(predicted==labels).item()
        end_time = time.time()
        print('Testing Time: ',end_time-start_time ,'s, Testing Accurarcy: ',correct/len(testset))
    print('-' * 20)

    

In [None]:
plot_graph(num_epochs,acc_list,loss_list)

In [None]:
#torch.save(model,'cycleGAN.model')